optirs_core/plugin/
registry.rs

// Plugin registry for managing and discovering optimizer plugins.
//
// This module provides a centralized registry for optimizer plugins,
// including registration, discovery, loading, and version management.
5
6use super::core::*;
7use crate::error::{OptimError, Result};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::path::{Path, PathBuf};
12use std::sync::{Mutex, RwLock};
13
/// Central plugin registry for managing all optimizer plugins.
///
/// All mutable state is wrapped in a `RwLock` or `Mutex`, so every public
/// operation takes `&self` and one instance can be shared. A lazily created
/// process-wide instance is returned by `PluginRegistry::global`.
#[derive(Debug)]
pub struct PluginRegistry {
    /// Registered plugin factories, keyed by the plugin name from `PluginInfo`
    factories: RwLock<HashMap<String, PluginRegistration>>,
    /// Directories scanned by plugin discovery
    search_paths: RwLock<Vec<PathBuf>>,
    /// Registry configuration (read-only through the public `&self` API)
    config: RegistryConfig,
    /// Plugin instance cache together with its statistics
    cache: Mutex<PluginCache>,
    /// Listeners notified of registration, loading and status-change events
    event_listeners: RwLock<Vec<Box<dyn RegistryEventListener>>>,
}
28
/// Plugin registration entry: a factory plus the registry's bookkeeping for it.
#[derive(Debug)]
pub struct PluginRegistration {
    /// Type-erased factory used to build optimizer instances
    pub factory: Box<dyn PluginFactoryWrapper>,
    /// Plugin metadata captured from `factory.info()` at registration time
    pub info: PluginInfo,
    /// Wall-clock time at which the plugin was registered
    pub registered_at: std::time::SystemTime,
    /// Plugin status. Starts as `Active`; `create_optimizer` refuses
    /// `Disabled`, `Failed` and `Maintenance` plugins.
    pub status: PluginStatus,
    /// Number of optimizers successfully created from this plugin
    pub load_count: usize,
    /// Time of the most recent successful optimizer creation, if any
    pub last_used: Option<std::time::SystemTime>,
}
45
/// Wrapper trait for type-erased plugin factories.
///
/// The registry stores factories as `Box<dyn PluginFactoryWrapper>`.
/// `PluginRegistry::create_optimizer` calls `create_f32` or `create_f64`
/// depending on the precision the caller requests.
pub trait PluginFactoryWrapper: Debug + Send + Sync {
    /// Create optimizer with f32 precision
    fn create_f32(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<f32>>>;

    /// Create optimizer with f64 precision.
    ///
    /// When registration-time validation is enabled, the registry calls this
    /// with `default_config()` to check that the factory works.
    fn create_f64(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<f64>>>;

    /// Get factory information (the metadata the registry indexes and searches)
    fn info(&self) -> PluginInfo;

    /// Validate configuration; the registry calls this before every creation
    fn validate_config(&self, config: &OptimizerConfig) -> Result<()>;

    /// Get default configuration
    fn default_config(&self) -> OptimizerConfig;

    /// Get configuration schema
    fn config_schema(&self) -> ConfigSchema;

    /// Check if factory supports the given data type
    fn supports_type(&self, datatype: &DataType) -> bool;
}
69
/// Lifecycle status of a registered plugin.
///
/// Only `Active` and `Deprecated` plugins can be instantiated through
/// `PluginRegistry::create_optimizer`. `Deprecated` also prints a warning
/// to stderr.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PluginStatus {
    /// Plugin is active and available
    Active,
    /// Plugin is disabled; creation fails with `PluginDisabled`
    Disabled,
    /// Plugin failed to load; the message is returned as `PluginLoadError`
    Failed(String),
    /// Plugin is deprecated but still usable
    Deprecated,
    /// Plugin is in maintenance mode; creation fails with `PluginInMaintenance`
    Maintenance,
}
84
/// Registry configuration.
///
/// NOTE(review): within this file only `auto_discovery` (read by
/// `discover_plugins`) and `validate_on_registration` (read by
/// `register_plugin`) are consulted. The remaining fields are not read here yet.
#[derive(Debug, Clone)]
pub struct RegistryConfig {
    /// Enable automatic plugin discovery from the search paths
    pub auto_discovery: bool,
    /// Enable plugin validation on registration
    pub validate_on_registration: bool,
    /// Enable plugin caching
    pub enable_caching: bool,
    /// Maximum cache size (presumably a number of cached instances — TODO confirm)
    pub max_cache_size: usize,
    /// Plugin load timeout
    pub load_timeout: std::time::Duration,
    /// Enable plugin sandboxing (future feature)
    pub enable_sandboxing: bool,
    /// Allowed plugin sources
    pub allowed_sources: Vec<PluginSource>,
}
103
/// Origin from which a plugin may be obtained.
#[derive(Debug, Clone)]
pub enum PluginSource {
    /// Built-in plugins
    BuiltIn,
    /// Local filesystem path (the default config uses `./plugins`)
    Local(PathBuf),
    /// Remote repository (the string is presumably a URL — TODO confirm)
    Remote(String),
    /// Package manager (the string is presumably a package name — TODO confirm)
    Package(String),
}
116
/// Cache of instantiated plugins plus hit/miss statistics.
///
/// NOTE(review): nothing in this file inserts into the cache. Only
/// `PluginRegistry::clear_cache` and `PluginRegistry::get_cache_stats` touch it.
#[derive(Debug)]
pub struct PluginCache {
    /// Cached plugin instances (key presumably the plugin name — TODO confirm)
    instances: HashMap<String, CachedPlugin>,
    /// Cache statistics
    stats: CacheStats,
}
125
/// A cached plugin instance with access bookkeeping.
///
/// Only `f64` instances can be cached, because the stored plugin is an
/// `OptimizerPlugin<f64>`.
#[derive(Debug)]
pub struct CachedPlugin {
    /// Plugin instance
    pub plugin: Box<dyn OptimizerPlugin<f64>>,
    /// When the instance was placed in the cache
    pub cached_at: std::time::SystemTime,
    /// Number of times the cached instance has been accessed
    pub access_count: usize,
    /// Time of the most recent access
    pub last_accessed: std::time::SystemTime,
}
138
/// Cache statistics. `Default` yields all-zero counters, which is also the
/// state `PluginRegistry::clear_cache` resets them to.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Total cache hits
    pub hits: usize,
    /// Total cache misses
    pub misses: usize,
    /// Total evictions
    pub evictions: usize,
    /// Total memory used (bytes)
    pub memory_used: usize,
}
151
152/// Registry event listener trait
153pub trait RegistryEventListener: Debug + Send + Sync {
154    /// Called when a plugin is registered
155    fn on_plugin_registered(&mut self, info: &PluginInfo) {}
156
157    /// Called when a plugin is unregistered
158    fn on_plugin_unregistered(&mut self, name: &str) {}
159
160    /// Called when a plugin is loaded
161    fn on_plugin_loaded(&mut self, name: &str) {}
162
163    /// Called when a plugin fails to load
164    fn on_plugin_load_failed(&mut self, _name: &str, error: &str) {}
165
166    /// Called when a plugin is enabled/disabled
167    fn on_plugin_status_changed(&mut self, _name: &str, status: &PluginStatus) {}
168}
169
/// Plugin search query. A plugin must satisfy every populated filter
/// (`required_capabilities` excepted, see below).
///
/// `Default` produces a query with no filters, which matches every plugin.
#[derive(Debug, Clone, Default)]
pub struct PluginQuery {
    /// Substring that must occur in the plugin name (case-sensitive)
    pub name_pattern: Option<String>,
    /// Plugin category filter (must be equal)
    pub category: Option<PluginCategory>,
    /// Required capabilities.
    ///
    /// NOTE(review): `PluginRegistry::matches_query` does not consult this
    /// field, so it currently has no effect on search results.
    pub required_capabilities: Vec<String>,
    /// Supported data types; the plugin matches if it supports at least one
    pub data_types: Vec<DataType>,
    /// Version requirements
    pub version_requirements: Option<VersionRequirement>,
    /// Tags filter; the plugin matches if it has at least one of these tags
    pub tags: Vec<String>,
    /// Maximum number of results returned (`total_count` is counted before truncation)
    pub limit: Option<usize>,
}
188
/// Version requirement specification.
///
/// When `exact_version` is set it takes precedence and both bounds are ignored.
#[derive(Debug, Clone)]
pub struct VersionRequirement {
    /// Minimum version (inclusive)
    pub min_version: Option<String>,
    /// Maximum version (exclusive)
    pub max_version: Option<String>,
    /// Exact version match (string equality)
    pub exact_version: Option<String>,
}
199
/// Result of `PluginRegistry::search_plugins`.
#[derive(Debug, Clone)]
pub struct PluginSearchResult {
    /// Matching plugins, truncated to `query.limit` when a limit is set
    pub plugins: Vec<PluginInfo>,
    /// Total number of matches before the limit was applied
    pub total_count: usize,
    /// Search query used
    pub query: PluginQuery,
    /// Time spent executing the search
    pub search_time: std::time::Duration,
}
212
213impl PluginRegistry {
214    /// Create a new plugin registry
215    pub fn new(config: RegistryConfig) -> Self {
216        Self {
217            factories: RwLock::new(HashMap::new()),
218            search_paths: RwLock::new(Vec::new()),
219            config,
220            cache: Mutex::new(PluginCache::new()),
221            event_listeners: RwLock::new(Vec::new()),
222        }
223    }
224
225    /// Get the global plugin registry instance
226    pub fn global() -> &'static Self {
227        static INSTANCE: std::sync::OnceLock<PluginRegistry> = std::sync::OnceLock::new();
228        INSTANCE.get_or_init(|| {
229            let config = RegistryConfig::default();
230            let mut registry = PluginRegistry::new(config);
231            registry.register_builtin_plugins();
232            registry
233        })
234    }
235
236    /// Register a plugin factory
237    pub fn register_plugin<F>(&self, factory: F) -> Result<()>
238    where
239        F: PluginFactoryWrapper + 'static,
240    {
241        let info = factory.info();
242        let name = info.name.clone();
243
244        // Validate plugin if enabled
245        if self.config.validate_on_registration {
246            self.validate_plugin(&factory)?;
247        }
248
249        let registration = PluginRegistration {
250            factory: Box::new(factory),
251            info: info.clone(),
252            registered_at: std::time::SystemTime::now(),
253            status: PluginStatus::Active,
254            load_count: 0,
255            last_used: None,
256        };
257
258        {
259            let mut factories = self.factories.write().unwrap();
260            factories.insert(name.clone(), registration);
261        }
262
263        // Notify event listeners
264        {
265            let mut listeners = self.event_listeners.write().unwrap();
266            for listener in listeners.iter_mut() {
267                listener.on_plugin_registered(&info);
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Unregister a plugin
275    pub fn unregister_plugin(&self, name: &str) -> Result<()> {
276        let mut factories = self.factories.write().unwrap();
277        if factories.remove(name).is_some() {
278            // Notify event listeners
279            drop(factories);
280            let mut listeners = self.event_listeners.write().unwrap();
281            for listener in listeners.iter_mut() {
282                listener.on_plugin_unregistered(name);
283            }
284            Ok(())
285        } else {
286            Err(OptimError::PluginNotFound(name.to_string()))
287        }
288    }
289
290    /// Create optimizer instance from plugin
291    pub fn create_optimizer<A>(
292        &self,
293        name: &str,
294        config: OptimizerConfig,
295    ) -> Result<Box<dyn OptimizerPlugin<A>>>
296    where
297        A: Float + Debug + Send + Sync + 'static,
298    {
299        let factories = self.factories.read().unwrap();
300        let registration = factories
301            .get(name)
302            .ok_or_else(|| OptimError::PluginNotFound(name.to_string()))?;
303
304        // Check plugin status
305        match registration.status {
306            PluginStatus::Active => {}
307            PluginStatus::Disabled => {
308                return Err(OptimError::PluginDisabled(name.to_string()));
309            }
310            PluginStatus::Failed(ref error) => {
311                return Err(OptimError::PluginLoadError(error.clone()));
312            }
313            PluginStatus::Deprecated => {
314                // Log warning but continue
315                eprintln!("Warning: Plugin '{}' is deprecated", name);
316            }
317            PluginStatus::Maintenance => {
318                return Err(OptimError::PluginInMaintenance(name.to_string()));
319            }
320        }
321
322        // Validate configuration
323        registration.factory.validate_config(&config)?;
324
325        // Create optimizer based on type
326        let optimizer = if std::any::TypeId::of::<A>() == std::any::TypeId::of::<f32>() {
327            let opt = registration.factory.create_f32(config)?;
328            // This is safe because we checked the type
329            unsafe {
330                std::mem::transmute::<Box<dyn OptimizerPlugin<f32>>, Box<dyn OptimizerPlugin<A>>>(
331                    opt,
332                )
333            }
334        } else if std::any::TypeId::of::<A>() == std::any::TypeId::of::<f64>() {
335            let opt = registration.factory.create_f64(config)?;
336            // This is safe because we checked the type
337            unsafe {
338                std::mem::transmute::<Box<dyn OptimizerPlugin<f64>>, Box<dyn OptimizerPlugin<A>>>(
339                    opt,
340                )
341            }
342        } else {
343            return Err(OptimError::UnsupportedDataType(format!(
344                "Type {} not supported",
345                std::any::type_name::<A>()
346            )));
347        };
348
349        // Update usage statistics
350        drop(factories);
351        let mut factories = self.factories.write().unwrap();
352        if let Some(registration) = factories.get_mut(name) {
353            registration.load_count += 1;
354            registration.last_used = Some(std::time::SystemTime::now());
355        }
356
357        // Notify event listeners
358        drop(factories);
359        let mut listeners = self.event_listeners.write().unwrap();
360        for listener in listeners.iter_mut() {
361            listener.on_plugin_loaded(name);
362        }
363
364        Ok(optimizer)
365    }
366
367    /// List all registered plugins
368    pub fn list_plugins(&self) -> Vec<PluginInfo> {
369        let factories = self.factories.read().unwrap();
370        factories.values().map(|reg| reg.info.clone()).collect()
371    }
372
373    /// Search for plugins matching criteria
374    pub fn search_plugins(&self, query: PluginQuery) -> PluginSearchResult {
375        let start_time = std::time::Instant::now();
376        let factories = self.factories.read().unwrap();
377
378        let mut matching_plugins = Vec::new();
379
380        for registration in factories.values() {
381            if self.matches_query(&registration.info, &query) {
382                matching_plugins.push(registration.info.clone());
383            }
384        }
385
386        let total_count = matching_plugins.len();
387
388        // Apply limit if specified
389        if let Some(limit) = query.limit {
390            matching_plugins.truncate(limit);
391        }
392
393        let search_time = start_time.elapsed();
394
395        PluginSearchResult {
396            plugins: matching_plugins,
397            total_count,
398            query,
399            search_time,
400        }
401    }
402
403    /// Get plugin information
404    pub fn get_plugin_info(&self, name: &str) -> Option<PluginInfo> {
405        let factories = self.factories.read().unwrap();
406        factories.get(name).map(|reg| reg.info.clone())
407    }
408
409    /// Get plugin status
410    pub fn get_plugin_status(&self, name: &str) -> Option<PluginStatus> {
411        let factories = self.factories.read().unwrap();
412        factories.get(name).map(|reg| reg.status.clone())
413    }
414
415    /// Enable/disable plugin
416    pub fn set_plugin_status(&self, name: &str, status: PluginStatus) -> Result<()> {
417        let mut factories = self.factories.write().unwrap();
418        let registration = factories
419            .get_mut(name)
420            .ok_or_else(|| OptimError::PluginNotFound(name.to_string()))?;
421
422        let old_status = registration.status.clone();
423        registration.status = status.clone();
424
425        // Notify event listeners if status changed
426        if old_status != status {
427            drop(factories);
428            let mut listeners = self.event_listeners.write().unwrap();
429            for listener in listeners.iter_mut() {
430                listener.on_plugin_status_changed(name, &status);
431            }
432        }
433
434        Ok(())
435    }
436
437    /// Add plugin search path
438    pub fn add_search_path<P: AsRef<Path>>(&self, path: P) {
439        let mut search_paths = self.search_paths.write().unwrap();
440        search_paths.push(path.as_ref().to_path_buf());
441    }
442
443    /// Discover plugins in search paths
444    pub fn discover_plugins(&self) -> Result<usize> {
445        if !self.config.auto_discovery {
446            return Ok(0);
447        }
448
449        let search_paths = self.search_paths.read().unwrap();
450        let mut discovered_count = 0;
451
452        for path in search_paths.iter() {
453            if path.exists() && path.is_dir() {
454                discovered_count += self.discover_plugins_in_directory(path)?;
455            }
456        }
457
458        Ok(discovered_count)
459    }
460
461    /// Add event listener
462    pub fn add_event_listener(&self, listener: Box<dyn RegistryEventListener>) {
463        let mut listeners = self.event_listeners.write().unwrap();
464        listeners.push(listener);
465    }
466
467    /// Get cache statistics
468    pub fn get_cache_stats(&self) -> CacheStats {
469        let cache = self.cache.lock().unwrap();
470        cache.stats.clone()
471    }
472
473    /// Clear plugin cache
474    pub fn clear_cache(&self) {
475        let mut cache = self.cache.lock().unwrap();
476        cache.instances.clear();
477        cache.stats = CacheStats::default();
478    }
479
480    // Private helper methods
481
482    fn validate_plugin(&self, factory: &dyn PluginFactoryWrapper) -> Result<()> {
483        // Basic validation - check if plugin can be created
484        let config = factory.default_config();
485        let _optimizer = factory.create_f64(config)?;
486        Ok(())
487    }
488
489    fn matches_query(&self, info: &PluginInfo, query: &PluginQuery) -> bool {
490        // Check name pattern
491        if let Some(ref pattern) = query.name_pattern {
492            if !info.name.contains(pattern) {
493                return false;
494            }
495        }
496
497        // Check category
498        if let Some(ref category) = query.category {
499            if info.category != *category {
500                return false;
501            }
502        }
503
504        // Check data types
505        if !query.data_types.is_empty() {
506            let has_common_type = query
507                .data_types
508                .iter()
509                .any(|dt| info.supported_types.contains(dt));
510            if !has_common_type {
511                return false;
512            }
513        }
514
515        // Check tags
516        if !query.tags.is_empty() {
517            let has_common_tag = query.tags.iter().any(|tag| info.tags.contains(tag));
518            if !has_common_tag {
519                return false;
520            }
521        }
522
523        // Check version requirements
524        if let Some(ref version_req) = query.version_requirements {
525            if !self.version_matches(&info.version, version_req) {
526                return false;
527            }
528        }
529
530        true
531    }
532
533    fn version_matches(&self, version: &str, requirement: &VersionRequirement) -> bool {
534        // Simplified version matching - in practice would use semver
535        if let Some(ref exact) = requirement.exact_version {
536            return version == exact;
537        }
538
539        if let Some(ref min) = requirement.min_version {
540            if version < min.as_str() {
541                return false;
542            }
543        }
544
545        if let Some(ref max) = requirement.max_version {
546            if version >= max.as_str() {
547                return false;
548            }
549        }
550
551        true
552    }
553
554    fn discover_plugins_in_directory(&self, path: &Path) -> Result<usize> {
555        // In a real implementation, this would scan for plugin files
556        // and attempt to load them dynamically
557        Ok(0)
558    }
559
560    fn register_builtin_plugins(&mut self) {
561        // Register built-in plugins would go here
562        // For now, this is a placeholder
563    }
564}
565
566impl PluginCache {
567    fn new() -> Self {
568        Self {
569            instances: HashMap::new(),
570            stats: CacheStats::default(),
571        }
572    }
573}
574
575impl Default for RegistryConfig {
576    fn default() -> Self {
577        Self {
578            auto_discovery: true,
579            validate_on_registration: true,
580            enable_caching: true,
581            max_cache_size: 100,
582            load_timeout: std::time::Duration::from_secs(30),
583            enable_sandboxing: false,
584            allowed_sources: vec![
585                PluginSource::BuiltIn,
586                PluginSource::Local(PathBuf::from("./plugins")),
587            ],
588        }
589    }
590}
591
/// Register a plugin factory with the global plugin registry.
///
/// Expands to `$crate::plugin::PluginRegistry::global().register_plugin(factory)?`.
/// It must therefore be used inside a function whose error type can absorb
/// the registry's error through `?`. A trailing comma is accepted.
#[macro_export]
macro_rules! register_optimizer_plugin {
    ($factory:expr $(,)?) => {
        $crate::plugin::PluginRegistry::global().register_plugin($factory)?
    };
}
599
600// Builder pattern for plugin queries
601pub struct PluginQueryBuilder {
602    query: PluginQuery,
603}
604
605impl Default for PluginQueryBuilder {
606    fn default() -> Self {
607        Self::new()
608    }
609}
610
611impl PluginQueryBuilder {
612    pub fn new() -> Self {
613        Self {
614            query: PluginQuery::default(),
615        }
616    }
617
618    pub fn name_pattern(mut self, pattern: &str) -> Self {
619        self.query.name_pattern = Some(pattern.to_string());
620        self
621    }
622
623    pub fn category(mut self, category: PluginCategory) -> Self {
624        self.query.category = Some(category);
625        self
626    }
627
628    pub fn data_type(mut self, datatype: DataType) -> Self {
629        self.query.data_types.push(datatype);
630        self
631    }
632
633    pub fn tag(mut self, tag: &str) -> Self {
634        self.query.tags.push(tag.to_string());
635        self
636    }
637
638    pub fn limit(mut self, limit: usize) -> Self {
639        self.query.limit = Some(limit);
640        self
641    }
642
643    pub fn build(self) -> PluginQuery {
644        self.query
645    }
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plugin_registry_creation() {
        let config = RegistryConfig::default();
        let registry = PluginRegistry::new(config);
        assert_eq!(registry.list_plugins().len(), 0);
    }

    #[test]
    fn test_plugin_query_builder() {
        let query = PluginQueryBuilder::new()
            .name_pattern("adam")
            .category(PluginCategory::FirstOrder)
            .data_type(DataType::F32)
            .tag("adaptive")
            .limit(10)
            .build();

        assert_eq!(query.name_pattern, Some("adam".to_string()));
        assert_eq!(query.category, Some(PluginCategory::FirstOrder));
        assert_eq!(query.data_types, vec![DataType::F32]);
        assert_eq!(query.tags, vec!["adaptive".to_string()]);
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_unregister_unknown_plugin_fails() {
        let registry = PluginRegistry::new(RegistryConfig::default());
        assert!(matches!(
            registry.unregister_plugin("missing"),
            Err(OptimError::PluginNotFound(ref name)) if name == "missing"
        ));
    }

    #[test]
    fn test_set_status_unknown_plugin_fails() {
        let registry = PluginRegistry::new(RegistryConfig::default());
        let result = registry.set_plugin_status("missing", PluginStatus::Disabled);
        assert!(matches!(result, Err(OptimError::PluginNotFound(_))));
        assert_eq!(registry.get_plugin_status("missing"), None);
    }

    #[test]
    fn test_search_empty_registry() {
        let registry = PluginRegistry::new(RegistryConfig::default());
        let result = registry.search_plugins(PluginQueryBuilder::new().limit(5).build());
        assert!(result.plugins.is_empty());
        assert_eq!(result.total_count, 0);
        assert_eq!(result.query.limit, Some(5));
    }

    #[test]
    fn test_discovery_disabled_finds_nothing() {
        let config = RegistryConfig {
            auto_discovery: false,
            ..RegistryConfig::default()
        };
        let registry = PluginRegistry::new(config);
        registry.add_search_path(".");
        assert!(matches!(registry.discover_plugins(), Ok(0)));
    }

    #[test]
    fn test_cache_stats_start_empty() {
        let registry = PluginRegistry::new(RegistryConfig::default());
        registry.clear_cache();
        let stats = registry.get_cache_stats();
        assert_eq!(
            (stats.hits, stats.misses, stats.evictions, stats.memory_used),
            (0, 0, 0, 0)
        );
    }

    #[test]
    fn test_version_requirement_bounds() {
        let registry = PluginRegistry::new(RegistryConfig::default());

        let range = VersionRequirement {
            min_version: Some("1.0.0".to_string()),
            max_version: Some("2.0.0".to_string()),
            exact_version: None,
        };
        assert!(registry.version_matches("1.0.0", &range));
        assert!(registry.version_matches("1.5.3", &range));
        assert!(!registry.version_matches("0.9.9", &range));
        assert!(!registry.version_matches("2.0.0", &range));

        let exact = VersionRequirement {
            min_version: None,
            max_version: None,
            exact_version: Some("1.2.3".to_string()),
        };
        assert!(registry.version_matches("1.2.3", &exact));
        assert!(!registry.version_matches("1.2.4", &exact));
    }
}