use super::core::*;
use crate::error::{OptimError, Result};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use std::sync::{Mutex, RwLock};
/// Registry of optimizer plugin factories.
///
/// All mutable state sits behind interior locks, so the registry's methods take
/// `&self`; a process-wide instance is available through `PluginRegistry::global`.
#[derive(Debug)]
pub struct PluginRegistry {
    /// Registered factories keyed by plugin name (`PluginInfo::name`).
    factories: RwLock<HashMap<String, PluginRegistration>>,
    /// Directories scanned by `discover_plugins`.
    search_paths: RwLock<Vec<PathBuf>>,
    /// Registry settings; fixed at construction (not behind a lock, no setter).
    config: RegistryConfig,
    /// Cache of `f64` plugin instances plus its statistics.
    cache: Mutex<PluginCache>,
    /// Observers notified of registration, unregistration, load and status events.
    event_listeners: RwLock<Vec<Box<dyn RegistryEventListener>>>,
}
/// A registered plugin factory together with its bookkeeping metadata.
#[derive(Debug)]
pub struct PluginRegistration {
    /// Factory used to instantiate optimizers for this plugin.
    pub factory: Box<dyn PluginFactoryWrapper>,
    /// Metadata captured from the factory when it was registered.
    pub info: PluginInfo,
    /// Wall-clock time at which the registration was created.
    pub registered_at: std::time::SystemTime,
    /// Availability state; `create_optimizer` only builds optimizers for
    /// `Active` and `Deprecated` plugins.
    pub status: PluginStatus,
    /// Number of optimizers successfully created from this registration.
    pub load_count: usize,
    /// When an optimizer was last created from this registration, if ever.
    pub last_used: Option<std::time::SystemTime>,
}
/// Object-safe factory interface through which the registry builds plugins.
///
/// The registry only instantiates `f32` and `f64` optimizers, hence the two
/// concrete constructors instead of a generic one.
pub trait PluginFactoryWrapper: Debug + Send + Sync {
    /// Builds an `f32` optimizer from `config`.
    fn create_f32(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<f32>>>;
    /// Builds an `f64` optimizer from `config`.
    fn create_f64(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<f64>>>;
    /// Returns the plugin metadata; its `name` becomes the registry key.
    fn info(&self) -> PluginInfo;
    /// Checks `config`; the registry calls this before creating an optimizer.
    fn validate_config(&self, config: &OptimizerConfig) -> Result<()>;
    /// Returns the default configuration (used by the registry to validate
    /// the factory at registration time).
    fn default_config(&self) -> OptimizerConfig;
    /// Describes the configuration options this plugin accepts.
    fn config_schema(&self) -> ConfigSchema;
    /// Reports whether the plugin supports `datatype`.
    fn supports_type(&self, datatype: &DataType) -> bool;
}
/// Lifecycle state of a registered plugin.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PluginStatus {
    /// Available for use.
    Active,
    /// Turned off; creating an optimizer fails with `PluginDisabled`.
    Disabled,
    /// Failed to load; the message is returned as `PluginLoadError` on creation.
    Failed(String),
    /// Still usable, but creation prints a deprecation warning to stderr.
    Deprecated,
    /// Temporarily unavailable; creation fails with `PluginInMaintenance`.
    Maintenance,
}
/// Settings controlling registry behaviour.
#[derive(Debug, Clone)]
pub struct RegistryConfig {
    /// When `false`, `discover_plugins` returns `Ok(0)` without scanning.
    pub auto_discovery: bool,
    /// When `true`, a factory must build an `f64` optimizer from its default
    /// config before it is registered.
    pub validate_on_registration: bool,
    /// Whether instance caching is enabled.
    /// NOTE(review): not read by the registry code in this file.
    pub enable_caching: bool,
    /// Presumably the maximum number of cached instances.
    /// NOTE(review): not enforced in this file.
    pub max_cache_size: usize,
    /// Presumably a time budget for loading a plugin.
    /// NOTE(review): not enforced in this file.
    pub load_timeout: std::time::Duration,
    /// Whether plugins should be sandboxed.
    /// NOTE(review): not enforced in this file.
    pub enable_sandboxing: bool,
    /// Sources plugins may be loaded from.
    /// NOTE(review): not checked in this file.
    pub allowed_sources: Vec<PluginSource>,
}
/// Origin from which a plugin may be loaded.
#[derive(Debug, Clone)]
pub enum PluginSource {
    /// Plugins registered by the registry itself.
    BuiltIn,
    /// A local filesystem path (e.g. a plugin directory).
    Local(PathBuf),
    /// A remote location; presumably a URL — TODO confirm format.
    Remote(String),
    /// A named package; the naming scheme is not established here.
    Package(String),
}
/// Cache of instantiated `f64` plugins together with usage statistics.
#[derive(Debug)]
pub struct PluginCache {
    /// Cached instances; presumably keyed by plugin name — verify against callers.
    instances: HashMap<String, CachedPlugin>,
    /// Hit/miss/eviction counters; reset by `PluginRegistry::clear_cache`.
    stats: CacheStats,
}
/// A cached `f64` plugin instance and its access metadata.
#[derive(Debug)]
pub struct CachedPlugin {
    /// The cached optimizer instance.
    pub plugin: Box<dyn OptimizerPlugin<f64>>,
    /// When the instance entered the cache.
    pub cached_at: std::time::SystemTime,
    /// Number of times the cached instance has been accessed.
    pub access_count: usize,
    /// Time of the most recent access.
    pub last_accessed: std::time::SystemTime,
}
/// Counters describing plugin-cache effectiveness; all start at zero.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Lookups served from the cache.
    pub hits: usize,
    /// Lookups that did not find a cached instance.
    pub misses: usize,
    /// Instances removed to make room for others.
    pub evictions: usize,
    /// Memory attributed to cached instances; units not established here (presumably bytes).
    pub memory_used: usize,
}
pub trait RegistryEventListener: Debug + Send + Sync {
fn on_plugin_registered(&mut self, info: &PluginInfo) {}
fn on_plugin_unregistered(&mut self, name: &str) {}
fn on_plugin_loaded(&mut self, name: &str) {}
fn on_plugin_load_failed(&mut self, _name: &str, error: &str) {}
fn on_plugin_status_changed(&mut self, _name: &str, status: &PluginStatus) {}
}
/// Filter criteria for `PluginRegistry::search_plugins`.
///
/// Every criterion that is set must match; `None` or an empty list imposes no
/// constraint.
#[derive(Debug, Clone, Default)]
pub struct PluginQuery {
    /// Substring the plugin name must contain.
    pub name_pattern: Option<String>,
    /// Category the plugin must belong to.
    pub category: Option<PluginCategory>,
    /// Capabilities the plugin should provide.
    /// NOTE(review): not evaluated by the registry's query matching.
    pub required_capabilities: Vec<String>,
    /// A plugin matches if it supports at least one of these data types.
    pub data_types: Vec<DataType>,
    /// Version constraint the plugin must satisfy.
    pub version_requirements: Option<VersionRequirement>,
    /// A plugin matches if it has at least one of these tags.
    pub tags: Vec<String>,
    /// Maximum number of plugins returned; `total_count` still counts all matches.
    pub limit: Option<usize>,
}
/// Version constraint used in a [`PluginQuery`].
///
/// When `exact_version` is set it alone decides the match; otherwise the
/// version must be `>= min_version` (inclusive) and `< max_version`
/// (exclusive), each bound being optional. `Default` yields an unconstrained
/// requirement, which allows struct-update syntax such as
/// `VersionRequirement { min_version: Some("1.0".into()), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct VersionRequirement {
    /// Inclusive lower bound.
    pub min_version: Option<String>,
    /// Exclusive upper bound.
    pub max_version: Option<String>,
    /// Exact version to match; overrides both bounds when present.
    pub exact_version: Option<String>,
}
/// Outcome of `PluginRegistry::search_plugins`.
#[derive(Debug, Clone)]
pub struct PluginSearchResult {
    /// Matching plugins, truncated to `query.limit` when a limit is set.
    pub plugins: Vec<PluginInfo>,
    /// Number of matches before the limit was applied.
    pub total_count: usize,
    /// The query that produced this result.
    pub query: PluginQuery,
    /// Wall-clock time spent performing the search.
    pub search_time: std::time::Duration,
}
impl PluginRegistry {
pub fn new(config: RegistryConfig) -> Self {
Self {
factories: RwLock::new(HashMap::new()),
search_paths: RwLock::new(Vec::new()),
config,
cache: Mutex::new(PluginCache::new()),
event_listeners: RwLock::new(Vec::new()),
}
}
pub fn global() -> &'static Self {
static INSTANCE: std::sync::OnceLock<PluginRegistry> = std::sync::OnceLock::new();
INSTANCE.get_or_init(|| {
let config = RegistryConfig::default();
let mut registry = PluginRegistry::new(config);
registry.register_builtin_plugins();
registry
})
}
pub fn register_plugin<F>(&self, factory: F) -> Result<()>
where
F: PluginFactoryWrapper + 'static,
{
let info = factory.info();
let name = info.name.clone();
if self.config.validate_on_registration {
self.validate_plugin(&factory)?;
}
let registration = PluginRegistration {
factory: Box::new(factory),
info: info.clone(),
registered_at: std::time::SystemTime::now(),
status: PluginStatus::Active,
load_count: 0,
last_used: None,
};
{
let mut factories = self.factories.write().expect("lock poisoned");
factories.insert(name.clone(), registration);
}
{
let mut listeners = self.event_listeners.write().expect("lock poisoned");
for listener in listeners.iter_mut() {
listener.on_plugin_registered(&info);
}
}
Ok(())
}
pub fn unregister_plugin(&self, name: &str) -> Result<()> {
let mut factories = self.factories.write().expect("lock poisoned");
if factories.remove(name).is_some() {
drop(factories);
let mut listeners = self.event_listeners.write().expect("lock poisoned");
for listener in listeners.iter_mut() {
listener.on_plugin_unregistered(name);
}
Ok(())
} else {
Err(OptimError::PluginNotFound(name.to_string()))
}
}
pub fn create_optimizer<A>(
&self,
name: &str,
config: OptimizerConfig,
) -> Result<Box<dyn OptimizerPlugin<A>>>
where
A: Float + Debug + Send + Sync + 'static,
{
let factories = self.factories.read().expect("lock poisoned");
let registration = factories
.get(name)
.ok_or_else(|| OptimError::PluginNotFound(name.to_string()))?;
match registration.status {
PluginStatus::Active => {}
PluginStatus::Disabled => {
return Err(OptimError::PluginDisabled(name.to_string()));
}
PluginStatus::Failed(ref error) => {
return Err(OptimError::PluginLoadError(error.clone()));
}
PluginStatus::Deprecated => {
eprintln!("Warning: Plugin '{}' is deprecated", name);
}
PluginStatus::Maintenance => {
return Err(OptimError::PluginInMaintenance(name.to_string()));
}
}
registration.factory.validate_config(&config)?;
let optimizer = if std::any::TypeId::of::<A>() == std::any::TypeId::of::<f32>() {
let opt = registration.factory.create_f32(config)?;
unsafe {
std::mem::transmute::<Box<dyn OptimizerPlugin<f32>>, Box<dyn OptimizerPlugin<A>>>(
opt,
)
}
} else if std::any::TypeId::of::<A>() == std::any::TypeId::of::<f64>() {
let opt = registration.factory.create_f64(config)?;
unsafe {
std::mem::transmute::<Box<dyn OptimizerPlugin<f64>>, Box<dyn OptimizerPlugin<A>>>(
opt,
)
}
} else {
return Err(OptimError::UnsupportedDataType(format!(
"Type {} not supported",
std::any::type_name::<A>()
)));
};
drop(factories);
let mut factories = self.factories.write().expect("lock poisoned");
if let Some(registration) = factories.get_mut(name) {
registration.load_count += 1;
registration.last_used = Some(std::time::SystemTime::now());
}
drop(factories);
let mut listeners = self.event_listeners.write().expect("lock poisoned");
for listener in listeners.iter_mut() {
listener.on_plugin_loaded(name);
}
Ok(optimizer)
}
pub fn list_plugins(&self) -> Vec<PluginInfo> {
let factories = self.factories.read().expect("lock poisoned");
factories.values().map(|reg| reg.info.clone()).collect()
}
pub fn search_plugins(&self, query: PluginQuery) -> PluginSearchResult {
let start_time = std::time::Instant::now();
let factories = self.factories.read().expect("lock poisoned");
let mut matching_plugins = Vec::new();
for registration in factories.values() {
if self.matches_query(®istration.info, &query) {
matching_plugins.push(registration.info.clone());
}
}
let total_count = matching_plugins.len();
if let Some(limit) = query.limit {
matching_plugins.truncate(limit);
}
let search_time = start_time.elapsed();
PluginSearchResult {
plugins: matching_plugins,
total_count,
query,
search_time,
}
}
pub fn get_plugin_info(&self, name: &str) -> Option<PluginInfo> {
let factories = self.factories.read().expect("lock poisoned");
factories.get(name).map(|reg| reg.info.clone())
}
pub fn get_plugin_status(&self, name: &str) -> Option<PluginStatus> {
let factories = self.factories.read().expect("lock poisoned");
factories.get(name).map(|reg| reg.status.clone())
}
pub fn set_plugin_status(&self, name: &str, status: PluginStatus) -> Result<()> {
let mut factories = self.factories.write().expect("lock poisoned");
let registration = factories
.get_mut(name)
.ok_or_else(|| OptimError::PluginNotFound(name.to_string()))?;
let old_status = registration.status.clone();
registration.status = status.clone();
if old_status != status {
drop(factories);
let mut listeners = self.event_listeners.write().expect("lock poisoned");
for listener in listeners.iter_mut() {
listener.on_plugin_status_changed(name, &status);
}
}
Ok(())
}
pub fn add_search_path<P: AsRef<Path>>(&self, path: P) {
let mut search_paths = self.search_paths.write().expect("lock poisoned");
search_paths.push(path.as_ref().to_path_buf());
}
pub fn discover_plugins(&self) -> Result<usize> {
if !self.config.auto_discovery {
return Ok(0);
}
let search_paths = self.search_paths.read().expect("lock poisoned");
let mut discovered_count = 0;
for path in search_paths.iter() {
if path.exists() && path.is_dir() {
discovered_count += self.discover_plugins_in_directory(path)?;
}
}
Ok(discovered_count)
}
pub fn add_event_listener(&self, listener: Box<dyn RegistryEventListener>) {
let mut listeners = self.event_listeners.write().expect("lock poisoned");
listeners.push(listener);
}
pub fn get_cache_stats(&self) -> CacheStats {
let cache = self.cache.lock().expect("lock poisoned");
cache.stats.clone()
}
pub fn clear_cache(&self) {
let mut cache = self.cache.lock().expect("lock poisoned");
cache.instances.clear();
cache.stats = CacheStats::default();
}
fn validate_plugin(&self, factory: &dyn PluginFactoryWrapper) -> Result<()> {
let config = factory.default_config();
let _optimizer = factory.create_f64(config)?;
Ok(())
}
fn matches_query(&self, info: &PluginInfo, query: &PluginQuery) -> bool {
if let Some(ref pattern) = query.name_pattern {
if !info.name.contains(pattern) {
return false;
}
}
if let Some(ref category) = query.category {
if info.category != *category {
return false;
}
}
if !query.data_types.is_empty() {
let has_common_type = query
.data_types
.iter()
.any(|dt| info.supported_types.contains(dt));
if !has_common_type {
return false;
}
}
if !query.tags.is_empty() {
let has_common_tag = query.tags.iter().any(|tag| info.tags.contains(tag));
if !has_common_tag {
return false;
}
}
if let Some(ref version_req) = query.version_requirements {
if !self.version_matches(&info.version, version_req) {
return false;
}
}
true
}
fn version_matches(&self, version: &str, requirement: &VersionRequirement) -> bool {
if let Some(ref exact) = requirement.exact_version {
return version == exact;
}
if let Some(ref min) = requirement.min_version {
if version < min.as_str() {
return false;
}
}
if let Some(ref max) = requirement.max_version {
if version >= max.as_str() {
return false;
}
}
true
}
fn discover_plugins_in_directory(&self, path: &Path) -> Result<usize> {
Ok(0)
}
fn register_builtin_plugins(&mut self) {
}
}
impl PluginCache {
    /// Builds an empty cache whose statistics counters are all zero.
    fn new() -> Self {
        let instances = HashMap::new();
        let stats = CacheStats::default();
        Self { instances, stats }
    }
}
impl Default for RegistryConfig {
fn default() -> Self {
Self {
auto_discovery: true,
validate_on_registration: true,
enable_caching: true,
max_cache_size: 100,
load_timeout: std::time::Duration::from_secs(30),
enable_sandboxing: false,
allowed_sources: vec![
PluginSource::BuiltIn,
PluginSource::Local(PathBuf::from("./plugins")),
],
}
}
}
/// Registers a plugin factory with the global plugin registry.
///
/// Expands to `$crate::plugin::PluginRegistry::global().register_plugin(factory)?`.
/// Because of the trailing `?`, the macro can only be used inside a function
/// whose return type can absorb the registry's error. An optional trailing
/// comma after the factory expression is accepted.
#[macro_export]
macro_rules! register_optimizer_plugin {
    ($factory:expr $(,)?) => {
        $crate::plugin::PluginRegistry::global().register_plugin($factory)?
    };
}
/// Fluent builder for a [`PluginQuery`].
///
/// Create it with `PluginQueryBuilder::new()` (or `Default`), chain setters,
/// and finish with `build()`.
#[derive(Debug, Clone)]
pub struct PluginQueryBuilder {
    /// The query being assembled.
    query: PluginQuery,
}
impl Default for PluginQueryBuilder {
fn default() -> Self {
Self::new()
}
}
impl PluginQueryBuilder {
pub fn new() -> Self {
Self {
query: PluginQuery::default(),
}
}
pub fn name_pattern(mut self, pattern: &str) -> Self {
self.query.name_pattern = Some(pattern.to_string());
self
}
pub fn category(mut self, category: PluginCategory) -> Self {
self.query.category = Some(category);
self
}
pub fn data_type(mut self, datatype: DataType) -> Self {
self.query.data_types.push(datatype);
self
}
pub fn tag(mut self, tag: &str) -> Self {
self.query.tags.push(tag.to_string());
self
}
pub fn limit(mut self, limit: usize) -> Self {
self.query.limit = Some(limit);
self
}
pub fn build(self) -> PluginQuery {
self.query
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plugin_registry_creation() {
        let config = RegistryConfig::default();
        let registry = PluginRegistry::new(config);
        assert_eq!(registry.list_plugins().len(), 0);
        assert!(registry.get_plugin_info("adam").is_none());
        assert!(registry.get_plugin_status("adam").is_none());
    }

    #[test]
    fn test_operations_on_unknown_plugin_fail() {
        let registry = PluginRegistry::new(RegistryConfig::default());
        assert!(registry.unregister_plugin("missing").is_err());
        assert!(registry
            .set_plugin_status("missing", PluginStatus::Disabled)
            .is_err());
    }

    #[test]
    fn test_discovery_without_search_paths_finds_nothing() {
        let registry = PluginRegistry::new(RegistryConfig::default());
        assert!(matches!(registry.discover_plugins(), Ok(0)));
    }

    #[test]
    fn test_cache_stats_start_at_zero() {
        let registry = PluginRegistry::new(RegistryConfig::default());
        registry.clear_cache();
        let stats = registry.get_cache_stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.memory_used, 0);
    }

    #[test]
    fn test_plugin_query_builder() {
        let query = PluginQueryBuilder::new()
            .name_pattern("adam")
            .category(PluginCategory::FirstOrder)
            .data_type(DataType::F32)
            .tag("adaptive")
            .limit(10)
            .build();
        assert_eq!(query.name_pattern, Some("adam".to_string()));
        assert_eq!(query.category, Some(PluginCategory::FirstOrder));
        assert_eq!(query.data_types, vec![DataType::F32]);
        assert_eq!(query.tags, vec!["adaptive".to_string()]);
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_default_builder_is_unconstrained() {
        let query = PluginQueryBuilder::default().build();
        assert!(query.name_pattern.is_none());
        assert!(query.category.is_none());
        assert!(query.data_types.is_empty());
        assert!(query.tags.is_empty());
        assert!(query.required_capabilities.is_empty());
        assert!(query.version_requirements.is_none());
        assert!(query.limit.is_none());
    }
}