use crate::messages::Message;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Descriptive information a plugin exposes about itself.
///
/// `name` is the key [`PluginManager`] registers the plugin under; registering
/// another plugin with the same name replaces the earlier entry.
// NOTE: field order determines serde output order and Debug output; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginMetadata {
    /// Plugin identifier, used as the registry key by `PluginManager`.
    pub name: String,
    /// Version string (the bundled plugins use `"1.0.0"`).
    pub version: String,
    /// Author shown in the plugin listing.
    pub author: String,
    /// Human-readable summary of what the plugin does.
    pub description: String,
    /// Capabilities the plugin advertises. This list is informational:
    /// `PluginManager` dispatches hooks based on `Plugin::handles_hook`, not on it.
    pub capabilities: Vec<PluginCapability>,
}
/// A kind of functionality a plugin can advertise in
/// [`PluginMetadata::capabilities`].
///
/// Declaring a capability does not by itself cause [`PluginManager`] to call the
/// plugin; dispatch is decided by [`Plugin::handles_hook`].
// NOTE: some serde formats (non-self-describing ones) encode unit variants by
// index, so append new variants rather than reordering existing ones.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PluginCapability {
    /// Declared by the bundled plugins that act on `PluginHook::BeforeProcess`.
    PreProcessMessage,
    /// Declared by the bundled logger, which also acts on `PluginHook::AfterProcess`.
    PostProcessMessage,
    // NOTE(review): the variants below are not declared by any plugin in this
    // file; their meaning is inferred from their names only — confirm with users.
    CustomCommands,
    ModifyResults,
    CustomUI,
    ExternalIntegration,
}
/// A point in processing at which [`PluginManager::execute_hook`] invokes every
/// plugin whose [`Plugin::handles_hook`] accepts it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PluginHook {
    /// Handled by the bundled message logger and profanity filter.
    BeforeProcess,
    /// Handled by the bundled message logger.
    AfterProcess,
    // NOTE(review): no plugin in this file handles the hooks below; presumably
    // they bracket SPARQL query execution and LLM calls — confirm at call sites.
    BeforeSPARQL,
    AfterSPARQL,
    BeforeLLM,
    AfterLLM,
}
/// Per-invocation state handed to plugins by [`PluginManager::execute_hook`].
///
/// Every plugin run for one hook receives the same context, so changes made by
/// one plugin are visible to the plugins that run after it.
#[derive(Debug, Clone)]
pub struct PluginContext {
    /// Session identifier; the message logger includes it in each entry.
    pub session_id: String,
    /// Message being processed, if any; the bundled plugins do nothing when `None`.
    pub message: Option<Message>,
    /// Free-form key/value data that plugins may read and write; for example
    /// `ProfanityFilterPlugin` stores its output under `"filtered_content"`.
    pub data: HashMap<String, serde_json::Value>,
}
/// Extension point implemented by every plugin managed by [`PluginManager`].
///
/// `#[async_trait]` lets the async methods be called through `Box<dyn Plugin>`,
/// which is how `PluginManager` stores plugins.
#[async_trait]
pub trait Plugin: Send + Sync {
    /// Static description of this plugin; its `name` is the registry key.
    fn metadata(&self) -> &PluginMetadata;

    /// Reports whether [`Plugin::execute`] should be called for `hook`.
    fn handles_hook(&self, hook: PluginHook) -> bool;

    /// Setup hook. The default implementation does nothing and succeeds.
    async fn initialize(&mut self) -> Result<()> {
        Ok(())
    }

    /// Runs the plugin for `hook`, possibly reading or mutating `context`.
    ///
    /// `PluginManager::execute_hook` calls this only when `handles_hook(hook)`
    /// is `true`, and logs (rather than propagates) any error returned.
    async fn execute(&self, hook: PluginHook, context: &mut PluginContext) -> Result<()>;

    /// Cleanup hook, awaited by `PluginManager::unregister_plugin`.
    /// The default implementation does nothing and succeeds.
    async fn shutdown(&mut self) -> Result<()> {
        Ok(())
    }
}
pub struct PluginManager {
plugins: Arc<RwLock<HashMap<String, Box<dyn Plugin>>>>,
}
impl PluginManager {
pub fn new() -> Self {
info!("Initialized plugin manager");
Self {
plugins: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register_plugin(&self, plugin: Box<dyn Plugin>) -> Result<()> {
let metadata = plugin.metadata();
let plugin_name = metadata.name.clone();
info!("Registering plugin: {} v{}", plugin_name, metadata.version);
let mut plugins = self.plugins.write().await;
plugins.insert(plugin_name, plugin);
Ok(())
}
pub async fn unregister_plugin(&self, plugin_name: &str) -> Result<()> {
info!("Unregistering plugin: {}", plugin_name);
let mut plugins = self.plugins.write().await;
if let Some(mut plugin) = plugins.remove(plugin_name) {
plugin.shutdown().await?;
}
Ok(())
}
pub async fn execute_hook(&self, hook: PluginHook, context: &mut PluginContext) -> Result<()> {
debug!("Executing hook: {:?}", hook);
let plugins = self.plugins.read().await;
for (name, plugin) in plugins.iter() {
if plugin.handles_hook(hook) {
debug!("Executing plugin {} for hook {:?}", name, hook);
if let Err(e) = plugin.execute(hook, context).await {
warn!("Plugin {} failed at hook {:?}: {}", name, hook, e);
}
}
}
Ok(())
}
pub async fn list_plugins(&self) -> Vec<PluginMetadata> {
let plugins = self.plugins.read().await;
plugins.values().map(|p| p.metadata().clone()).collect()
}
pub async fn get_plugin(&self, name: &str) -> Option<PluginMetadata> {
let plugins = self.plugins.read().await;
plugins.get(name).map(|p| p.metadata().clone())
}
}
impl Default for PluginManager {
fn default() -> Self {
Self::new()
}
}
/// Bundled plugin that logs the context's message on
/// `PluginHook::BeforeProcess` and `PluginHook::AfterProcess`.
pub struct MessageLoggerPlugin {
    metadata: PluginMetadata,
    /// Path that entries are appended to; when `None`, entries are emitted
    /// through `tracing` at `info` level instead.
    log_file: Option<String>,
}
impl MessageLoggerPlugin {
    /// Creates the message-logger plugin.
    ///
    /// `log_file` is a path to append entries to. Pass `None` to send entries
    /// through `tracing` at `info` level instead.
    pub fn new(log_file: Option<String>) -> Self {
        let capabilities = vec![
            PluginCapability::PreProcessMessage,
            PluginCapability::PostProcessMessage,
        ];
        let metadata = PluginMetadata {
            name: String::from("message-logger"),
            version: String::from("1.0.0"),
            author: String::from("OxiRS Team"),
            description: String::from("Logs all messages to a file"),
            capabilities,
        };
        Self { metadata, log_file }
    }
}
#[async_trait]
impl Plugin for MessageLoggerPlugin {
fn metadata(&self) -> &PluginMetadata {
&self.metadata
}
async fn execute(&self, hook: PluginHook, context: &mut PluginContext) -> Result<()> {
if let Some(message) = &context.message {
let log_entry = format!(
"[{:?}] Session: {}, Message: {}",
hook, context.session_id, message.content
);
if let Some(ref file) = self.log_file {
use std::fs::OpenOptions;
use std::io::Write;
let mut file = OpenOptions::new().create(true).append(true).open(file)?;
writeln!(file, "{}", log_entry)?;
} else {
info!("{}", log_entry);
}
}
Ok(())
}
fn handles_hook(&self, hook: PluginHook) -> bool {
matches!(hook, PluginHook::BeforeProcess | PluginHook::AfterProcess)
}
}
/// Bundled plugin that runs on `PluginHook::BeforeProcess`.
///
/// It masks blocked words in the context's message with `*` and stores the
/// result in the context's `data` under `"filtered_content"`. The message
/// itself is left unchanged.
pub struct ProfanityFilterPlugin {
    metadata: PluginMetadata,
    /// Words to mask; they are matched as substrings of the message text.
    blocked_words: Vec<String>,
}
impl ProfanityFilterPlugin {
    /// Creates a filter using the built-in placeholder block list
    /// (`"badword1"` and `"badword2"`).
    pub fn new() -> Self {
        Self::with_blocked_words(vec!["badword1", "badword2"])
    }

    /// Creates a filter that masks the given `words` instead of the built-in list.
    pub fn with_blocked_words<I, S>(words: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            metadata: PluginMetadata {
                name: "profanity-filter".to_string(),
                version: "1.0.0".to_string(),
                author: "OxiRS Team".to_string(),
                description: "Filters profanity from messages".to_string(),
                capabilities: vec![PluginCapability::PreProcessMessage],
            },
            blocked_words: words.into_iter().map(Into::into).collect(),
        }
    }
}
/// Returns `text` with every ASCII-case-insensitive occurrence of `word`
/// replaced by one `*` per character of `word`.
///
/// Matches are substrings, found left to right without overlap. An empty
/// `word` leaves `text` unchanged.
fn mask_blocked_word(text: &str, word: &str) -> String {
    if word.is_empty() {
        return text.to_string();
    }
    // ASCII lowercasing rewrites only ASCII bytes. Every byte offset (and char
    // boundary) found in `haystack` is therefore equally valid in `text`.
    let haystack = text.to_ascii_lowercase();
    let needle = word.to_ascii_lowercase();
    let mask = "*".repeat(word.chars().count());
    let mut masked = String::with_capacity(text.len());
    let mut copied_up_to = 0;
    for (start, _) in haystack.match_indices(needle.as_str()) {
        masked.push_str(&text[copied_up_to..start]);
        masked.push_str(&mask);
        copied_up_to = start + needle.len();
    }
    masked.push_str(&text[copied_up_to..]);
    masked
}

#[async_trait]
impl Plugin for ProfanityFilterPlugin {
    fn metadata(&self) -> &PluginMetadata {
        &self.metadata
    }

    /// Masks every blocked word in the context's message and stores the result
    /// under `"filtered_content"` in `context.data`.
    ///
    /// Matching is ASCII case-insensitive, so "BadWord1" is caught as well.
    /// Each match becomes one `*` per character of the blocked word. The
    /// message itself is not modified, and contexts without a message are
    /// ignored.
    async fn execute(&self, _hook: PluginHook, context: &mut PluginContext) -> Result<()> {
        if let Some(message) = &context.message {
            let filtered = self
                .blocked_words
                .iter()
                .fold(message.content.to_string(), |text, word| {
                    mask_blocked_word(&text, word)
                });
            context
                .data
                .insert("filtered_content".to_string(), serde_json::json!(filtered));
        }
        Ok(())
    }

    /// Filtering happens only before processing.
    fn handles_hook(&self, hook: PluginHook) -> bool {
        matches!(hook, PluginHook::BeforeProcess)
    }
}
impl Default for ProfanityFilterPlugin {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{MessageContent, MessageRole};
    use chrono::Utc;

    /// Builds a minimal user message carrying `text`.
    fn user_message(text: &str) -> Message {
        Message {
            id: uuid::Uuid::new_v4().to_string(),
            role: MessageRole::User,
            content: MessageContent::from_text(text.to_string()),
            timestamp: Utc::now(),
            metadata: None,
            thread_id: None,
            parent_message_id: None,
            token_count: None,
            reactions: Vec::new(),
            attachments: Vec::new(),
            rich_elements: Vec::new(),
        }
    }

    /// Wraps `message` in a context for a fixed test session with no extra data.
    fn test_context(message: Message) -> PluginContext {
        PluginContext {
            session_id: "test-session".to_string(),
            message: Some(message),
            data: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_plugin_registration() {
        let manager = PluginManager::new();
        let plugin = Box::new(MessageLoggerPlugin::new(None));
        manager
            .register_plugin(plugin)
            .await
            .expect("should succeed");
        let plugins = manager.list_plugins().await;
        assert_eq!(plugins.len(), 1);
        assert_eq!(plugins[0].name, "message-logger");
    }

    #[tokio::test]
    async fn test_plugin_execution() {
        let manager = PluginManager::new();
        let plugin = Box::new(MessageLoggerPlugin::new(None));
        manager
            .register_plugin(plugin)
            .await
            .expect("should succeed");
        let mut context = test_context(user_message("Test message"));
        manager
            .execute_hook(PluginHook::BeforeProcess, &mut context)
            .await
            .expect("should succeed");
        // The logger only observes the context; it must not alter it.
        assert_eq!(context.session_id, "test-session");
        assert!(context.message.is_some());
        assert!(context.data.is_empty());
    }

    #[tokio::test]
    async fn test_unregister_removes_plugin() {
        let manager = PluginManager::new();
        manager
            .register_plugin(Box::new(MessageLoggerPlugin::new(None)))
            .await
            .expect("register should succeed");
        manager
            .unregister_plugin("message-logger")
            .await
            .expect("unregister should succeed");
        assert!(manager.list_plugins().await.is_empty());
        assert!(manager.get_plugin("message-logger").await.is_none());
    }

    #[tokio::test]
    async fn test_unregister_unknown_plugin_is_noop() {
        let manager = PluginManager::new();
        manager
            .unregister_plugin("does-not-exist")
            .await
            .expect("unknown plugin names should be ignored");
    }

    #[tokio::test]
    async fn test_profanity_filter_masks_blocked_words() {
        let manager = PluginManager::new();
        manager
            .register_plugin(Box::new(ProfanityFilterPlugin::new()))
            .await
            .expect("should succeed");
        let message = user_message("please do not say badword1 here");
        // Derive the expectation from the content's own rendering so the test
        // does not depend on MessageContent's exact Display format.
        let expected = message.content.to_string().replace("badword1", "********");
        let mut context = test_context(message);
        manager
            .execute_hook(PluginHook::BeforeProcess, &mut context)
            .await
            .expect("should succeed");
        let filtered = context
            .data
            .get("filtered_content")
            .and_then(|value| value.as_str())
            .expect("filter should store filtered_content");
        assert_eq!(filtered, expected);
        assert!(!filtered.contains("badword1"));
    }

    #[tokio::test]
    async fn test_profanity_filter_ignores_other_hooks() {
        let manager = PluginManager::new();
        manager
            .register_plugin(Box::new(ProfanityFilterPlugin::new()))
            .await
            .expect("should succeed");
        let mut context = test_context(user_message("badword1"));
        manager
            .execute_hook(PluginHook::AfterProcess, &mut context)
            .await
            .expect("should succeed");
        assert!(!context.data.contains_key("filtered_content"));
    }
}