use crate::errors::{Result, TrustformersError};
use crate::plugins::{Plugin, PluginInfo, PluginLoader, PluginManager};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
/// Registry of plugin metadata and loaded plugin instances.
///
/// Every piece of mutable state is held behind an `Arc<RwLock<..>>`, so the
/// registering/loading/unloading methods all take `&self`.
#[derive(Debug)]
pub struct PluginRegistry {
/// Registered plugin metadata, keyed by plugin name (filled by `register`,
/// `scan_for_plugins` and `import_config`).
plugins: Arc<RwLock<HashMap<String, PluginInfo>>>,
/// Plugin instances that `load_plugin` has loaded and initialized, keyed by name.
loaded: Arc<RwLock<HashMap<String, Box<dyn Plugin>>>>,
/// Directories whose entries `scan_for_plugins` inspects for plugin files.
search_paths: Arc<RwLock<Vec<PathBuf>>>,
/// Reads plugin metadata from files and instantiates plugins from metadata.
loader: Arc<PluginLoader>,
/// Registry settings; `export_config` writes them out together with the
/// current plugin map.
// NOTE(review): apart from `export_config`, no visible code reads this field,
// so limits such as `max_loaded_plugins` are not enforced here — confirm intent.
config: RegistryConfig,
}
impl PluginRegistry {
pub fn new() -> Self {
Self {
plugins: Arc::new(RwLock::new(HashMap::new())),
loaded: Arc::new(RwLock::new(HashMap::new())),
search_paths: Arc::new(RwLock::new(Vec::new())),
loader: Arc::new(PluginLoader::new()),
config: RegistryConfig::default(),
}
}
pub fn with_config(config: RegistryConfig) -> Self {
Self {
plugins: Arc::new(RwLock::new(HashMap::new())),
loaded: Arc::new(RwLock::new(HashMap::new())),
search_paths: Arc::new(RwLock::new(Vec::new())),
loader: Arc::new(PluginLoader::new()),
config,
}
}
pub fn register(&self, name: &str, info: PluginInfo) -> Result<()> {
info.validate()?;
let mut plugins = self.plugins.write().map_err(|_| {
TrustformersError::lock_error("Failed to acquire write lock".to_string())
})?;
if plugins.contains_key(name) {
return Err(TrustformersError::plugin_error(format!(
"Plugin '{}' is already registered",
name
)));
}
plugins.insert(name.to_string(), info);
Ok(())
}
pub fn unregister(&self, name: &str) -> Result<()> {
if self.is_loaded(name) {
self.unload_plugin(name)?;
}
let mut plugins = self.plugins.write().map_err(|_| {
TrustformersError::lock_error("Failed to acquire write lock".to_string())
})?;
plugins.remove(name).ok_or_else(|| {
TrustformersError::plugin_error(format!("Plugin '{}' not found", name))
})?;
Ok(())
}
pub fn get_plugin_info(&self, name: &str) -> Result<PluginInfo> {
let plugins = self.plugins.read().map_err(|_| {
TrustformersError::lock_error("Failed to acquire read lock".to_string())
})?;
plugins
.get(name)
.cloned()
.ok_or_else(|| TrustformersError::plugin_error(format!("Plugin '{}' not found", name)))
}
pub fn is_loaded(&self, name: &str) -> bool {
self.loaded.read().map(|loaded| loaded.contains_key(name)).unwrap_or(false)
}
pub fn unload_plugin(&self, name: &str) -> Result<()> {
let mut loaded = self.loaded.write().map_err(|_| {
TrustformersError::lock_error("Failed to acquire write lock".to_string())
})?;
if let Some(mut plugin) = loaded.remove(name) {
plugin.cleanup()?;
}
Ok(())
}
pub fn add_search_path<P: AsRef<Path>>(&self, path: P) {
if let Ok(mut paths) = self.search_paths.write() {
paths.push(path.as_ref().to_path_buf());
}
}
pub fn remove_search_path<P: AsRef<Path>>(&self, path: P) {
if let Ok(mut paths) = self.search_paths.write() {
paths.retain(|p| p != path.as_ref());
}
}
pub fn scan_for_plugins(&self) -> Result<usize> {
let search_paths = self
.search_paths
.read()
.map_err(|_| TrustformersError::lock_error("Failed to acquire read lock".to_string()))?
.clone();
let mut count = 0;
for path in &search_paths {
if let Ok(entries) = std::fs::read_dir(path) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_file() && self.is_plugin_file(&path) {
if let Ok(info) = self.loader.load_plugin_info(&path) {
let name = info.name().to_string();
if self.register(&name, info).is_ok() {
count += 1;
}
}
}
}
}
}
Ok(count)
}
fn is_plugin_file(&self, path: &Path) -> bool {
if let Some(ext) = path.extension() {
let ext = ext.to_string_lossy().to_lowercase();
matches!(ext.as_str(), "so" | "dll" | "dylib" | "wasm")
} else {
false
}
}
pub fn validate_dependencies(&self, name: &str) -> Result<()> {
let info = self.get_plugin_info(name)?;
for dep in info.dependencies() {
if !dep.optional {
let dep_info = self.get_plugin_info(&dep.name)?;
if !dep.requirement.matches(dep_info.version()) {
return Err(TrustformersError::plugin_error(format!(
"Plugin '{}' requires '{}' {} but found {}",
name,
dep.name,
dep.requirement,
dep_info.version()
)));
}
}
}
Ok(())
}
pub fn export_config<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let plugins = self.plugins.read().map_err(|_| {
TrustformersError::lock_error("Failed to acquire read lock".to_string())
})?;
let config = RegistryConfig {
plugins: plugins.clone(),
..self.config.clone()
};
let json = serde_json::to_string_pretty(&config)
.map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
std::fs::write(path, json).map_err(|e| TrustformersError::io_error(e.to_string()))?;
Ok(())
}
pub fn import_config<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let json = std::fs::read_to_string(path)
.map_err(|e| TrustformersError::io_error(e.to_string()))?;
let config: RegistryConfig = serde_json::from_str(&json)
.map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
let mut plugins = self.plugins.write().map_err(|_| {
TrustformersError::lock_error("Failed to acquire write lock".to_string())
})?;
for (name, info) in config.plugins {
plugins.insert(name, info);
}
Ok(())
}
pub fn stats(&self) -> Result<RegistryStats> {
let plugins = self.plugins.read().map_err(|_| {
TrustformersError::lock_error("Failed to acquire read lock".to_string())
})?;
let loaded = self.loaded.read().map_err(|_| {
TrustformersError::lock_error("Failed to acquire read lock".to_string())
})?;
Ok(RegistryStats {
total_plugins: plugins.len(),
loaded_plugins: loaded.len(),
search_paths: self.search_paths.read().map(|paths| paths.len()).unwrap_or(0),
})
}
}
impl PluginManager for PluginRegistry {
    /// Returns a snapshot of all registered plugin metadata, keyed by name.
    fn discover_plugins(&self) -> Result<HashMap<String, PluginInfo>> {
        let plugins = self.plugins.read().map_err(|_| {
            TrustformersError::lock_error("Failed to acquire read lock".to_string())
        })?;
        Ok(plugins.clone())
    }

    /// Asks the registered metadata of `name` whether it is compatible with
    /// `trustformers-core` at `version`.
    fn is_compatible(&self, name: &str, version: &str) -> Result<bool> {
        let info = self.get_plugin_info(name)?;
        Ok(info.is_compatible_with("trustformers-core", version))
    }

    /// Returns a clone of the loaded instance of `name`, loading and
    /// initializing it first if necessary.
    ///
    /// Two threads can both miss the read-locked fast path and load the plugin
    /// concurrently. The write-locked section re-checks the map: the first
    /// instance inserted wins, and a losing duplicate is cleaned up instead of
    /// being silently overwritten (which would skip its `cleanup()`).
    ///
    /// # Errors
    /// Dependency validation, lookup, loader, `initialize()`, or lock errors.
    fn load_plugin(&self, name: &str) -> Result<Box<dyn Plugin>> {
        {
            let loaded = self.loaded.read().map_err(|_| {
                TrustformersError::lock_error("Failed to acquire read lock".to_string())
            })?;
            if let Some(plugin) = loaded.get(name) {
                return Ok(plugin.clone());
            }
        }

        // Load and initialize outside any lock; this may be slow.
        self.validate_dependencies(name)?;
        let info = self.get_plugin_info(name)?;
        let mut plugin = self.loader.load_plugin(&info)?;
        plugin.initialize()?;

        let mut loaded = self.loaded.write().map_err(|_| {
            TrustformersError::lock_error("Failed to acquire write lock".to_string())
        })?;
        if let Some(existing) = loaded.get(name) {
            // Another thread finished loading first: keep its instance.
            let winner = existing.clone();
            drop(loaded);
            // Best-effort release of our redundant instance; the caller still
            // gets a valid, already-registered plugin, so a cleanup failure
            // here is not reported.
            let _ = plugin.cleanup();
            return Ok(winner);
        }
        let handle = plugin.clone();
        loaded.insert(name.to_string(), plugin);
        Ok(handle)
    }

    /// Lists the names of all registered plugins (empty if the lock is poisoned).
    fn list_plugins(&self) -> Vec<String> {
        self.plugins
            .read()
            .map(|plugins| plugins.keys().cloned().collect())
            .unwrap_or_default()
    }
}
impl Default for PluginRegistry {
fn default() -> Self {
Self::new()
}
}
/// Serializable registry settings, written by `PluginRegistry::export_config`
/// and read by `PluginRegistry::import_config`.
///
/// Every field has a serde default, so a JSON document may omit any of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistryConfig {
/// Plugin metadata keyed by plugin name; defaults to an empty map.
#[serde(default)]
pub plugins: HashMap<String, PluginInfo>,
/// Presumably the maximum number of simultaneously loaded plugins (per the
/// name); defaults to 100 via `default_max_loaded`.
// NOTE(review): `load_plugin` in this file does not check this limit — confirm
// whether it is enforced elsewhere.
#[serde(default = "default_max_loaded")]
pub max_loaded_plugins: usize,
/// Auto-discovery flag; defaults to `true` via `default_auto_discovery`.
// NOTE(review): not read by any code visible in this file.
#[serde(default = "default_auto_discovery")]
pub auto_discovery: bool,
/// Optional cache directory; defaults to `None`.
// NOTE(review): not read by any code visible in this file.
#[serde(default)]
pub cache_dir: Option<PathBuf>,
/// Plugin load timeout, presumably in seconds (per the `_secs` suffix);
/// defaults to 30 via `default_load_timeout`.
#[serde(default = "default_load_timeout")]
pub load_timeout_secs: u64,
}
impl Default for RegistryConfig {
fn default() -> Self {
Self {
plugins: HashMap::new(),
max_loaded_plugins: default_max_loaded(),
auto_discovery: default_auto_discovery(),
cache_dir: None,
load_timeout_secs: default_load_timeout(),
}
}
}
/// Serde fallback for `RegistryConfig::max_loaded_plugins`.
fn default_max_loaded() -> usize {
    const MAX_LOADED: usize = 100;
    MAX_LOADED
}

/// Serde fallback for `RegistryConfig::auto_discovery`.
fn default_auto_discovery() -> bool {
    const AUTO_DISCOVERY: bool = true;
    AUTO_DISCOVERY
}

/// Serde fallback for `RegistryConfig::load_timeout_secs`.
fn default_load_timeout() -> u64 {
    const LOAD_TIMEOUT_SECS: u64 = 30;
    LOAD_TIMEOUT_SECS
}
/// Point-in-time counts reported by `PluginRegistry::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistryStats {
/// Number of registered plugins (entries in the metadata map).
pub total_plugins: usize,
/// Number of plugins with a loaded instance.
pub loaded_plugins: usize,
/// Number of configured search-path directories.
pub search_paths: usize,
}