1use super::core::*;
7use crate::error::{OptimError, Result};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::path::{Path, PathBuf};
12use std::sync::{Mutex, RwLock};
13
14#[derive(Debug)]
16pub struct PluginRegistry {
17 factories: RwLock<HashMap<String, PluginRegistration>>,
19 search_paths: RwLock<Vec<PathBuf>>,
21 config: RegistryConfig,
23 cache: Mutex<PluginCache>,
25 event_listeners: RwLock<Vec<Box<dyn RegistryEventListener>>>,
27}
28
29#[derive(Debug)]
31pub struct PluginRegistration {
32 pub factory: Box<dyn PluginFactoryWrapper>,
34 pub info: PluginInfo,
36 pub registered_at: std::time::SystemTime,
38 pub status: PluginStatus,
40 pub load_count: usize,
42 pub last_used: Option<std::time::SystemTime>,
44}
45
46pub trait PluginFactoryWrapper: Debug + Send + Sync {
48 fn create_f32(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<f32>>>;
50
51 fn create_f64(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<f64>>>;
53
54 fn info(&self) -> PluginInfo;
56
57 fn validate_config(&self, config: &OptimizerConfig) -> Result<()>;
59
60 fn default_config(&self) -> OptimizerConfig;
62
63 fn config_schema(&self) -> ConfigSchema;
65
66 fn supports_type(&self, datatype: &DataType) -> bool;
68}
69
/// Availability state of a registered plugin, consulted before an optimizer is created.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PluginStatus {
    /// Usable without restrictions.
    Active,
    /// Creation is refused with a "plugin disabled" error.
    Disabled,
    /// Creation is refused; the payload is reported as the load error.
    Failed(String),
    /// Usable, but a deprecation warning is printed to stderr on each creation.
    Deprecated,
    /// Creation is refused while the plugin is under maintenance.
    Maintenance,
}
84
85#[derive(Debug, Clone)]
87pub struct RegistryConfig {
88 pub auto_discovery: bool,
90 pub validate_on_registration: bool,
92 pub enable_caching: bool,
94 pub max_cache_size: usize,
96 pub load_timeout: std::time::Duration,
98 pub enable_sandboxing: bool,
100 pub allowed_sources: Vec<PluginSource>,
102}
103
/// Origin from which a plugin may be loaded.
///
/// Derives `PartialEq`/`Eq`/`Hash` so allow-lists of sources can be compared,
/// searched (`contains`) and stored in hash sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PluginSource {
    /// Plugins provided by the crate itself (presumably; confirm against loader code).
    BuiltIn,
    /// Plugins found under a local filesystem directory.
    Local(PathBuf),
    /// Plugins identified by a remote location string (e.g. a URL — TODO confirm).
    Remote(String),
    /// Plugins identified by a package name.
    Package(String),
}
116
117#[derive(Debug)]
119pub struct PluginCache {
120 instances: HashMap<String, CachedPlugin>,
122 stats: CacheStats,
124}
125
126#[derive(Debug)]
128pub struct CachedPlugin {
129 pub plugin: Box<dyn OptimizerPlugin<f64>>,
131 pub cached_at: std::time::SystemTime,
133 pub access_count: usize,
135 pub last_accessed: std::time::SystemTime,
137}
138
/// Snapshot of plugin-cache counters; `Default` yields all zeros.
///
/// `PartialEq`/`Eq` allow snapshots to be compared directly (e.g. in tests or
/// change detection) instead of field by field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups served from the cache.
    pub hits: usize,
    /// Lookups that had to create a new instance.
    pub misses: usize,
    /// Entries removed to make room.
    pub evictions: usize,
    /// Memory attributed to cached entries (unit not established here; presumably bytes).
    pub memory_used: usize,
}
151
152pub trait RegistryEventListener: Debug + Send + Sync {
154 fn on_plugin_registered(&mut self, info: &PluginInfo) {}
156
157 fn on_plugin_unregistered(&mut self, name: &str) {}
159
160 fn on_plugin_loaded(&mut self, name: &str) {}
162
163 fn on_plugin_load_failed(&mut self, _name: &str, error: &str) {}
165
166 fn on_plugin_status_changed(&mut self, _name: &str, status: &PluginStatus) {}
168}
169
170#[derive(Debug, Clone, Default)]
172pub struct PluginQuery {
173 pub name_pattern: Option<String>,
175 pub category: Option<PluginCategory>,
177 pub required_capabilities: Vec<String>,
179 pub data_types: Vec<DataType>,
181 pub version_requirements: Option<VersionRequirement>,
183 pub tags: Vec<String>,
185 pub limit: Option<usize>,
187}
188
/// Version constraint applied during plugin search.
///
/// `exact_version`, when set, takes precedence and must match the plugin version
/// textually. Otherwise `min_version` is an inclusive lower bound and `max_version` an
/// exclusive upper bound. `Default` leaves every bound unset, i.e. any version matches.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct VersionRequirement {
    /// Inclusive lower bound.
    pub min_version: Option<String>,
    /// Exclusive upper bound.
    pub max_version: Option<String>,
    /// Required exact version; overrides the other two bounds.
    pub exact_version: Option<String>,
}
199
200#[derive(Debug, Clone)]
202pub struct PluginSearchResult {
203 pub plugins: Vec<PluginInfo>,
205 pub total_count: usize,
207 pub query: PluginQuery,
209 pub search_time: std::time::Duration,
211}
212
213impl PluginRegistry {
214 pub fn new(config: RegistryConfig) -> Self {
216 Self {
217 factories: RwLock::new(HashMap::new()),
218 search_paths: RwLock::new(Vec::new()),
219 config,
220 cache: Mutex::new(PluginCache::new()),
221 event_listeners: RwLock::new(Vec::new()),
222 }
223 }
224
225 pub fn global() -> &'static Self {
227 static INSTANCE: std::sync::OnceLock<PluginRegistry> = std::sync::OnceLock::new();
228 INSTANCE.get_or_init(|| {
229 let config = RegistryConfig::default();
230 let mut registry = PluginRegistry::new(config);
231 registry.register_builtin_plugins();
232 registry
233 })
234 }
235
236 pub fn register_plugin<F>(&self, factory: F) -> Result<()>
238 where
239 F: PluginFactoryWrapper + 'static,
240 {
241 let info = factory.info();
242 let name = info.name.clone();
243
244 if self.config.validate_on_registration {
246 self.validate_plugin(&factory)?;
247 }
248
249 let registration = PluginRegistration {
250 factory: Box::new(factory),
251 info: info.clone(),
252 registered_at: std::time::SystemTime::now(),
253 status: PluginStatus::Active,
254 load_count: 0,
255 last_used: None,
256 };
257
258 {
259 let mut factories = self.factories.write().unwrap();
260 factories.insert(name.clone(), registration);
261 }
262
263 {
265 let mut listeners = self.event_listeners.write().unwrap();
266 for listener in listeners.iter_mut() {
267 listener.on_plugin_registered(&info);
268 }
269 }
270
271 Ok(())
272 }
273
274 pub fn unregister_plugin(&self, name: &str) -> Result<()> {
276 let mut factories = self.factories.write().unwrap();
277 if factories.remove(name).is_some() {
278 drop(factories);
280 let mut listeners = self.event_listeners.write().unwrap();
281 for listener in listeners.iter_mut() {
282 listener.on_plugin_unregistered(name);
283 }
284 Ok(())
285 } else {
286 Err(OptimError::PluginNotFound(name.to_string()))
287 }
288 }
289
290 pub fn create_optimizer<A>(
292 &self,
293 name: &str,
294 config: OptimizerConfig,
295 ) -> Result<Box<dyn OptimizerPlugin<A>>>
296 where
297 A: Float + Debug + Send + Sync + 'static,
298 {
299 let factories = self.factories.read().unwrap();
300 let registration = factories
301 .get(name)
302 .ok_or_else(|| OptimError::PluginNotFound(name.to_string()))?;
303
304 match registration.status {
306 PluginStatus::Active => {}
307 PluginStatus::Disabled => {
308 return Err(OptimError::PluginDisabled(name.to_string()));
309 }
310 PluginStatus::Failed(ref error) => {
311 return Err(OptimError::PluginLoadError(error.clone()));
312 }
313 PluginStatus::Deprecated => {
314 eprintln!("Warning: Plugin '{}' is deprecated", name);
316 }
317 PluginStatus::Maintenance => {
318 return Err(OptimError::PluginInMaintenance(name.to_string()));
319 }
320 }
321
322 registration.factory.validate_config(&config)?;
324
325 let optimizer = if std::any::TypeId::of::<A>() == std::any::TypeId::of::<f32>() {
327 let opt = registration.factory.create_f32(config)?;
328 unsafe {
330 std::mem::transmute::<Box<dyn OptimizerPlugin<f32>>, Box<dyn OptimizerPlugin<A>>>(
331 opt,
332 )
333 }
334 } else if std::any::TypeId::of::<A>() == std::any::TypeId::of::<f64>() {
335 let opt = registration.factory.create_f64(config)?;
336 unsafe {
338 std::mem::transmute::<Box<dyn OptimizerPlugin<f64>>, Box<dyn OptimizerPlugin<A>>>(
339 opt,
340 )
341 }
342 } else {
343 return Err(OptimError::UnsupportedDataType(format!(
344 "Type {} not supported",
345 std::any::type_name::<A>()
346 )));
347 };
348
349 drop(factories);
351 let mut factories = self.factories.write().unwrap();
352 if let Some(registration) = factories.get_mut(name) {
353 registration.load_count += 1;
354 registration.last_used = Some(std::time::SystemTime::now());
355 }
356
357 drop(factories);
359 let mut listeners = self.event_listeners.write().unwrap();
360 for listener in listeners.iter_mut() {
361 listener.on_plugin_loaded(name);
362 }
363
364 Ok(optimizer)
365 }
366
367 pub fn list_plugins(&self) -> Vec<PluginInfo> {
369 let factories = self.factories.read().unwrap();
370 factories.values().map(|reg| reg.info.clone()).collect()
371 }
372
373 pub fn search_plugins(&self, query: PluginQuery) -> PluginSearchResult {
375 let start_time = std::time::Instant::now();
376 let factories = self.factories.read().unwrap();
377
378 let mut matching_plugins = Vec::new();
379
380 for registration in factories.values() {
381 if self.matches_query(®istration.info, &query) {
382 matching_plugins.push(registration.info.clone());
383 }
384 }
385
386 let total_count = matching_plugins.len();
387
388 if let Some(limit) = query.limit {
390 matching_plugins.truncate(limit);
391 }
392
393 let search_time = start_time.elapsed();
394
395 PluginSearchResult {
396 plugins: matching_plugins,
397 total_count,
398 query,
399 search_time,
400 }
401 }
402
403 pub fn get_plugin_info(&self, name: &str) -> Option<PluginInfo> {
405 let factories = self.factories.read().unwrap();
406 factories.get(name).map(|reg| reg.info.clone())
407 }
408
409 pub fn get_plugin_status(&self, name: &str) -> Option<PluginStatus> {
411 let factories = self.factories.read().unwrap();
412 factories.get(name).map(|reg| reg.status.clone())
413 }
414
415 pub fn set_plugin_status(&self, name: &str, status: PluginStatus) -> Result<()> {
417 let mut factories = self.factories.write().unwrap();
418 let registration = factories
419 .get_mut(name)
420 .ok_or_else(|| OptimError::PluginNotFound(name.to_string()))?;
421
422 let old_status = registration.status.clone();
423 registration.status = status.clone();
424
425 if old_status != status {
427 drop(factories);
428 let mut listeners = self.event_listeners.write().unwrap();
429 for listener in listeners.iter_mut() {
430 listener.on_plugin_status_changed(name, &status);
431 }
432 }
433
434 Ok(())
435 }
436
437 pub fn add_search_path<P: AsRef<Path>>(&self, path: P) {
439 let mut search_paths = self.search_paths.write().unwrap();
440 search_paths.push(path.as_ref().to_path_buf());
441 }
442
443 pub fn discover_plugins(&self) -> Result<usize> {
445 if !self.config.auto_discovery {
446 return Ok(0);
447 }
448
449 let search_paths = self.search_paths.read().unwrap();
450 let mut discovered_count = 0;
451
452 for path in search_paths.iter() {
453 if path.exists() && path.is_dir() {
454 discovered_count += self.discover_plugins_in_directory(path)?;
455 }
456 }
457
458 Ok(discovered_count)
459 }
460
461 pub fn add_event_listener(&self, listener: Box<dyn RegistryEventListener>) {
463 let mut listeners = self.event_listeners.write().unwrap();
464 listeners.push(listener);
465 }
466
467 pub fn get_cache_stats(&self) -> CacheStats {
469 let cache = self.cache.lock().unwrap();
470 cache.stats.clone()
471 }
472
473 pub fn clear_cache(&self) {
475 let mut cache = self.cache.lock().unwrap();
476 cache.instances.clear();
477 cache.stats = CacheStats::default();
478 }
479
480 fn validate_plugin(&self, factory: &dyn PluginFactoryWrapper) -> Result<()> {
483 let config = factory.default_config();
485 let _optimizer = factory.create_f64(config)?;
486 Ok(())
487 }
488
489 fn matches_query(&self, info: &PluginInfo, query: &PluginQuery) -> bool {
490 if let Some(ref pattern) = query.name_pattern {
492 if !info.name.contains(pattern) {
493 return false;
494 }
495 }
496
497 if let Some(ref category) = query.category {
499 if info.category != *category {
500 return false;
501 }
502 }
503
504 if !query.data_types.is_empty() {
506 let has_common_type = query
507 .data_types
508 .iter()
509 .any(|dt| info.supported_types.contains(dt));
510 if !has_common_type {
511 return false;
512 }
513 }
514
515 if !query.tags.is_empty() {
517 let has_common_tag = query.tags.iter().any(|tag| info.tags.contains(tag));
518 if !has_common_tag {
519 return false;
520 }
521 }
522
523 if let Some(ref version_req) = query.version_requirements {
525 if !self.version_matches(&info.version, version_req) {
526 return false;
527 }
528 }
529
530 true
531 }
532
533 fn version_matches(&self, version: &str, requirement: &VersionRequirement) -> bool {
534 if let Some(ref exact) = requirement.exact_version {
536 return version == exact;
537 }
538
539 if let Some(ref min) = requirement.min_version {
540 if version < min.as_str() {
541 return false;
542 }
543 }
544
545 if let Some(ref max) = requirement.max_version {
546 if version >= max.as_str() {
547 return false;
548 }
549 }
550
551 true
552 }
553
554 fn discover_plugins_in_directory(&self, path: &Path) -> Result<usize> {
555 Ok(0)
558 }
559
560 fn register_builtin_plugins(&mut self) {
561 }
564}
565
566impl PluginCache {
567 fn new() -> Self {
568 Self {
569 instances: HashMap::new(),
570 stats: CacheStats::default(),
571 }
572 }
573}
574
575impl Default for RegistryConfig {
576 fn default() -> Self {
577 Self {
578 auto_discovery: true,
579 validate_on_registration: true,
580 enable_caching: true,
581 max_cache_size: 100,
582 load_timeout: std::time::Duration::from_secs(30),
583 enable_sandboxing: false,
584 allowed_sources: vec![
585 PluginSource::BuiltIn,
586 PluginSource::Local(PathBuf::from("./plugins")),
587 ],
588 }
589 }
590}
591
/// Registers a plugin factory with the global `PluginRegistry`.
///
/// Any registration error is propagated with `?`, so the macro must be used inside a
/// function whose error type can absorb the registry's error type.
#[macro_export]
macro_rules! register_optimizer_plugin {
    ($plugin_factory:expr) => {
        $crate::plugin::PluginRegistry::global().register_plugin($plugin_factory)?
    };
}
599
600pub struct PluginQueryBuilder {
602 query: PluginQuery,
603}
604
605impl Default for PluginQueryBuilder {
606 fn default() -> Self {
607 Self::new()
608 }
609}
610
611impl PluginQueryBuilder {
612 pub fn new() -> Self {
613 Self {
614 query: PluginQuery::default(),
615 }
616 }
617
618 pub fn name_pattern(mut self, pattern: &str) -> Self {
619 self.query.name_pattern = Some(pattern.to_string());
620 self
621 }
622
623 pub fn category(mut self, category: PluginCategory) -> Self {
624 self.query.category = Some(category);
625 self
626 }
627
628 pub fn data_type(mut self, datatype: DataType) -> Self {
629 self.query.data_types.push(datatype);
630 self
631 }
632
633 pub fn tag(mut self, tag: &str) -> Self {
634 self.query.tags.push(tag.to_string());
635 self
636 }
637
638 pub fn limit(mut self, limit: usize) -> Self {
639 self.query.limit = Some(limit);
640 self
641 }
642
643 pub fn build(self) -> PluginQuery {
644 self.query
645 }
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    fn empty_registry() -> PluginRegistry {
        PluginRegistry::new(RegistryConfig::default())
    }

    #[test]
    fn test_plugin_registry_creation() {
        let registry = empty_registry();
        assert_eq!(registry.list_plugins().len(), 0);
    }

    #[test]
    fn test_plugin_query_builder() {
        let query = PluginQueryBuilder::new()
            .name_pattern("adam")
            .category(PluginCategory::FirstOrder)
            .data_type(DataType::F32)
            .limit(10)
            .build();

        assert_eq!(query.name_pattern, Some("adam".to_string()));
        assert_eq!(query.category, Some(PluginCategory::FirstOrder));
        assert_eq!(query.data_types, vec![DataType::F32]);
        assert!(query.tags.is_empty());
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_unknown_plugin_operations() {
        let registry = empty_registry();
        assert!(registry.get_plugin_info("missing").is_none());
        assert!(registry.get_plugin_status("missing").is_none());
        assert!(registry.unregister_plugin("missing").is_err());
        assert!(registry
            .set_plugin_status("missing", PluginStatus::Disabled)
            .is_err());
    }

    #[test]
    fn test_search_on_empty_registry() {
        let registry = empty_registry();
        let result = registry.search_plugins(PluginQueryBuilder::new().tag("fast").limit(5).build());
        assert_eq!(result.total_count, 0);
        assert!(result.plugins.is_empty());
        assert_eq!(result.query.limit, Some(5));
    }

    #[test]
    fn test_discovery() {
        let disabled = PluginRegistry::new(RegistryConfig {
            auto_discovery: false,
            ..RegistryConfig::default()
        });
        disabled.add_search_path(std::env::temp_dir());
        assert!(matches!(disabled.discover_plugins(), Ok(0)));

        let enabled = empty_registry();
        enabled.add_search_path("/definitely/not/a/real/plugin/dir");
        assert!(matches!(enabled.discover_plugins(), Ok(0)));
    }

    #[test]
    fn test_cache_stats_start_empty() {
        let registry = empty_registry();
        let stats = registry.get_cache_stats();
        assert_eq!(
            (stats.hits, stats.misses, stats.evictions, stats.memory_used),
            (0, 0, 0, 0)
        );
        registry.clear_cache();
        assert_eq!(registry.get_cache_stats().hits, 0);
    }
}