// optirs_core/plugin/core.rs

// Core plugin traits and interfaces for optimizer development.
//
// This module defines the fundamental traits that custom optimizers must implement to
// integrate with the plugin system, together with the data structures they exchange with it.
5
6use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Float;
9use serde::{Deserialize, Serialize};
10use std::any::Any;
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::time::Duration;
14
15/// Main trait for optimizer plugins
16pub trait OptimizerPlugin<A: Float>: Debug + Send + Sync {
17    /// Perform a single optimization step
18    fn step(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>>;
19
20    /// Get optimizer name
21    fn name(&self) -> &str;
22
23    /// Get optimizer version
24    fn version(&self) -> &str;
25
26    /// Get plugin information
27    fn plugin_info(&self) -> PluginInfo;
28
29    /// Get optimizer capabilities
30    fn capabilities(&self) -> PluginCapabilities;
31
32    /// Initialize optimizer with parameters
33    fn initialize(&mut self, paramshape: &[usize]) -> Result<()>;
34
35    /// Reset optimizer state
36    fn reset(&mut self) -> Result<()>;
37
38    /// Get optimizer configuration
39    fn get_config(&self) -> OptimizerConfig;
40
41    /// Set optimizer configuration
42    fn set_config(&mut self, config: OptimizerConfig) -> Result<()>;
43
44    /// Get optimizer state for serialization
45    fn get_state(&self) -> Result<OptimizerState>;
46
47    /// Set optimizer state from deserialization
48    fn set_state(&mut self, state: OptimizerState) -> Result<()>;
49
50    /// Clone the optimizer plugin
51    fn clone_plugin(&self) -> Box<dyn OptimizerPlugin<A>>;
52
53    /// Get memory usage information
54    fn memory_usage(&self) -> MemoryUsage {
55        MemoryUsage::default()
56    }
57
58    /// Get performance metrics
59    fn performance_metrics(&self) -> PerformanceMetrics {
60        PerformanceMetrics::default()
61    }
62}
63
64/// Extended plugin trait for optimizers with advanced features
65pub trait ExtendedOptimizerPlugin<A: Float>: OptimizerPlugin<A> {
66    /// Perform batch optimization step
67    fn batch_step(&mut self, params: &Array2<A>, gradients: &Array2<A>) -> Result<Array2<A>>;
68
69    /// Compute adaptive learning rate
70    fn adaptive_learning_rate(&self, gradients: &Array1<A>) -> A;
71
72    /// Gradient preprocessing
73    fn preprocess_gradients(&self, gradients: &Array1<A>) -> Result<Array1<A>>;
74
75    /// Parameter postprocessing
76    fn postprocess_parameters(&self, params: &Array1<A>) -> Result<Array1<A>>;
77
78    /// Get optimization trajectory
79    fn get_trajectory(&self) -> Vec<Array1<A>>;
80
81    /// Compute convergence metrics
82    fn convergence_metrics(&self) -> ConvergenceMetrics;
83}
84
85/// Plugin information and metadata
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct PluginInfo {
88    /// Plugin name
89    pub name: String,
90    /// Plugin version
91    pub version: String,
92    /// Plugin author
93    pub author: String,
94    /// Plugin description
95    pub description: String,
96    /// Plugin homepage/repository
97    pub homepage: Option<String>,
98    /// Plugin license
99    pub license: String,
100    /// Supported data types
101    pub supported_types: Vec<DataType>,
102    /// Plugin category
103    pub category: PluginCategory,
104    /// Plugin tags for search/filtering
105    pub tags: Vec<String>,
106    /// Minimum SDK version required
107    pub min_sdk_version: String,
108    /// Plugin dependencies
109    pub dependencies: Vec<PluginDependency>,
110}
111
112/// Plugin capabilities and features
113#[derive(Debug, Clone, Serialize, Deserialize, Default)]
114pub struct PluginCapabilities {
115    /// Supports sparse gradients
116    pub sparse_gradients: bool,
117    /// Supports parameter groups
118    pub parameter_groups: bool,
119    /// Supports momentum
120    pub momentum: bool,
121    /// Supports adaptive learning rates
122    pub adaptive_learning_rate: bool,
123    /// Supports weight decay
124    pub weight_decay: bool,
125    /// Supports gradient clipping
126    pub gradient_clipping: bool,
127    /// Supports batch processing
128    pub batch_processing: bool,
129    /// Supports state serialization
130    pub state_serialization: bool,
131    /// Thread safety
132    pub thread_safe: bool,
133    /// Memory efficient
134    pub memory_efficient: bool,
135    /// GPU acceleration support
136    pub gpu_support: bool,
137    /// SIMD optimization
138    pub simd_optimized: bool,
139    /// Supports custom loss functions
140    pub custom_loss_functions: bool,
141    /// Supports regularization
142    pub regularization: bool,
143}
144
145/// Supported data types
146#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
147pub enum DataType {
148    F32,
149    F64,
150    I32,
151    I64,
152    Complex32,
153    Complex64,
154    Custom(String),
155}
156
157/// Plugin categories
158#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
159pub enum PluginCategory {
160    /// First-order optimizers (SGD, Adam, etc.)
161    FirstOrder,
162    /// Second-order optimizers (Newton, BFGS, etc.)
163    SecondOrder,
164    /// Specialized optimizers (domain-specific)
165    Specialized,
166    /// Meta-learning optimizers
167    MetaLearning,
168    /// Experimental optimizers
169    Experimental,
170    /// Utility/helper plugins
171    Utility,
172}
173
174/// Plugin dependency information
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct PluginDependency {
177    /// Dependency name
178    pub name: String,
179    /// Version requirement
180    pub version: String,
181    /// Whether dependency is optional
182    pub optional: bool,
183    /// Dependency type
184    pub dependency_type: DependencyType,
185}
186
187/// Types of plugin dependencies
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub enum DependencyType {
190    /// Another plugin
191    Plugin,
192    /// System library
193    SystemLibrary,
194    /// Rust crate
195    Crate,
196    /// Runtime requirement
197    Runtime,
198}
199
200/// Optimizer configuration
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct OptimizerConfig {
203    /// Learning rate
204    pub learning_rate: f64,
205    /// Weight decay
206    pub weight_decay: f64,
207    /// Momentum coefficient
208    pub momentum: f64,
209    /// Gradient clipping threshold
210    pub gradient_clip: Option<f64>,
211    /// Custom parameters
212    pub custom_params: HashMap<String, ConfigValue>,
213}
214
215/// Configuration value types
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub enum ConfigValue {
218    Float(f64),
219    Integer(i64),
220    Boolean(bool),
221    String(String),
222    Array(Vec<f64>),
223}
224
225/// Optimizer state for serialization
226#[derive(Debug, Clone, Serialize, Deserialize, Default)]
227pub struct OptimizerState {
228    /// Internal state vectors
229    pub state_vectors: HashMap<String, Vec<f64>>,
230    /// Step count
231    pub step_count: usize,
232    /// Custom state data
233    pub custom_state: HashMap<String, StateValue>,
234}
235
236/// State value types
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub enum StateValue {
239    Float(f64),
240    Integer(i64),
241    Boolean(bool),
242    String(String),
243    Array(Vec<f64>),
244    Matrix(Vec<Vec<f64>>),
245}
246
/// Memory consumption report for an optimizer. All fields default to zero.
#[derive(Clone, Debug, Default)]
pub struct MemoryUsage {
    /// Bytes currently in use.
    pub current_usage: usize,
    /// Peak number of bytes in use.
    pub peak_usage: usize,
    /// Memory-efficiency score, expected in `0.0..=1.0`.
    pub efficiency_score: f64,
}

/// Runtime performance report for an optimizer. All fields default to zero.
#[derive(Clone, Debug, Default)]
pub struct PerformanceMetrics {
    /// Mean duration of one step, in seconds.
    pub avg_step_time: f64,
    /// Number of steps performed.
    pub total_steps: usize,
    /// Steps per second.
    pub throughput: f64,
    /// CPU utilization, expected in `0.0..=1.0`.
    pub cpu_utilization: f64,
}

/// Convergence diagnostics for an optimizer. All fields default to zero.
#[derive(Clone, Debug, Default)]
pub struct ConvergenceMetrics {
    /// Norm of the gradient.
    pub gradient_norm: f64,
    /// Norm of the most recent parameter change.
    pub parameter_change_norm: f64,
    /// Rate at which the loss is improving.
    pub loss_improvement_rate: f64,
    /// Convergence score, expected in `0.0..=1.0`.
    pub convergence_score: f64,
}
283
/// Outcome of validating a plugin.
#[derive(Clone, Debug)]
pub struct PluginValidationResult {
    /// `true` when the plugin passed validation.
    pub is_valid: bool,
    /// Problems that caused validation to fail.
    pub errors: Vec<String>,
    /// Non-fatal issues found during validation.
    pub warnings: Vec<String>,
    /// Benchmark measurements, when benchmarks were run.
    pub benchmark_results: Option<BenchmarkResults>,
}

/// Raw measurements collected while benchmarking a plugin.
#[derive(Clone, Debug)]
pub struct BenchmarkResults {
    /// Measured execution times.
    pub execution_times: Vec<Duration>,
    /// Measured memory usage values.
    pub memory_usage: Vec<usize>,
    /// Accuracy scores.
    pub accuracy_scores: Vec<f64>,
    /// Convergence rates.
    pub convergence_rates: Vec<f64>,
}
309
310/// Plugin factory trait for creating optimizer instances
311pub trait OptimizerPluginFactory<A: Float>: Debug + Send + Sync {
312    /// Create a new optimizer instance
313    fn create_optimizer(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<A>>>;
314
315    /// Get factory information
316    fn factory_info(&self) -> PluginInfo;
317
318    /// Validate configuration
319    fn validate_config(&self, config: &OptimizerConfig) -> Result<()>;
320
321    /// Get default configuration
322    fn default_config(&self) -> OptimizerConfig;
323
324    /// Get configuration schema
325    fn config_schema(&self) -> ConfigSchema;
326}
327
328/// Configuration schema for validation and UI generation
329#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct ConfigSchema {
331    /// Schema fields
332    pub fields: HashMap<String, FieldSchema>,
333    /// Required fields
334    pub required_fields: Vec<String>,
335    /// Schema version
336    pub version: String,
337}
338
339/// Individual field schema
340#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct FieldSchema {
342    /// Field type
343    pub field_type: FieldType,
344    /// Field description
345    pub description: String,
346    /// Default value
347    pub default_value: Option<ConfigValue>,
348    /// Validation constraints
349    pub constraints: Vec<ValidationConstraint>,
350    /// Whether field is required
351    pub required: bool,
352}
353
354/// Field types for schema
355#[derive(Debug, Clone, Serialize, Deserialize)]
356pub enum FieldType {
357    Float {
358        min: Option<f64>,
359        max: Option<f64>,
360    },
361    Integer {
362        min: Option<i64>,
363        max: Option<i64>,
364    },
365    Boolean,
366    String {
367        max_length: Option<usize>,
368    },
369    Array {
370        element_type: Box<FieldType>,
371        max_length: Option<usize>,
372    },
373    Choice {
374        options: Vec<String>,
375    },
376}
377
378/// Validation constraints
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub enum ValidationConstraint {
381    /// Minimum value
382    Min(f64),
383    /// Maximum value
384    Max(f64),
385    /// Value must be positive
386    Positive,
387    /// Value must be non-negative
388    NonNegative,
389    /// Value must be in range
390    Range(f64, f64),
391    /// String must match regex pattern
392    Pattern(String),
393    /// Custom validation function name
394    Custom(String),
395}
396
397/// Plugin lifecycle hooks
398pub trait PluginLifecycle {
399    /// Called when plugin is loaded
400    fn on_load(&mut self) -> Result<()> {
401        Ok(())
402    }
403
404    /// Called when plugin is unloaded
405    fn on_unload(&mut self) -> Result<()> {
406        Ok(())
407    }
408
409    /// Called when plugin is enabled
410    fn on_enable(&mut self) -> Result<()> {
411        Ok(())
412    }
413
414    /// Called when plugin is disabled
415    fn on_disable(&mut self) -> Result<()> {
416        Ok(())
417    }
418
419    /// Called periodically for maintenance
420    fn on_maintenance(&mut self) -> Result<()> {
421        Ok(())
422    }
423}
424
425/// Plugin event system
426pub trait PluginEventHandler {
427    /// Handle optimization step event
428    fn on_step(&mut self, _step: usize, _params: &Array1<f64>, gradients: &Array1<f64>) {}
429
430    /// Handle convergence event
431    fn on_convergence(&mut self, _finalparams: &Array1<f64>) {}
432
433    /// Handle error event
434    fn on_error(&mut self, error: &OptimError) {}
435
436    /// Handle custom event
437    fn on_custom_event(&mut self, _event_name: &str, data: &dyn Any) {}
438}
439
440/// Plugin metadata provider
441pub trait PluginMetadata {
442    /// Get plugin documentation
443    fn documentation(&self) -> String {
444        String::new()
445    }
446
447    /// Get plugin examples
448    fn examples(&self) -> Vec<PluginExample> {
449        Vec::new()
450    }
451
452    /// Get plugin changelog
453    fn changelog(&self) -> String {
454        String::new()
455    }
456
457    /// Get plugin compatibility information
458    fn compatibility(&self) -> CompatibilityInfo {
459        CompatibilityInfo::default()
460    }
461}
462
/// A documented usage example for a plugin.
#[derive(Clone, Debug)]
pub struct PluginExample {
    /// Short title.
    pub title: String,
    /// What the example demonstrates.
    pub description: String,
    /// Example source code.
    pub code: String,
    /// Output the example is expected to produce.
    pub expected_output: String,
}

/// Compatibility notes for a plugin. Every list is empty by default.
#[derive(Clone, Debug, Default)]
pub struct CompatibilityInfo {
    /// Supported Rust versions.
    pub rust_versions: Vec<String>,
    /// Supported platforms.
    pub platforms: Vec<String>,
    /// Known issues.
    pub known_issues: Vec<String>,
    /// Breaking changes.
    pub breaking_changes: Vec<String>,
}
488
489// Default implementations
490
491impl Default for PluginInfo {
492    fn default() -> Self {
493        Self {
494            name: "Unknown".to_string(),
495            version: "0.1.0".to_string(),
496            author: "Unknown".to_string(),
497            description: "No description provided".to_string(),
498            homepage: None,
499            license: "MIT".to_string(),
500            supported_types: vec![DataType::F32, DataType::F64],
501            category: PluginCategory::FirstOrder,
502            tags: Vec::new(),
503            min_sdk_version: "0.1.0".to_string(),
504            dependencies: Vec::new(),
505        }
506    }
507}
508
509impl Default for OptimizerConfig {
510    fn default() -> Self {
511        Self {
512            learning_rate: 0.001,
513            weight_decay: 0.0,
514            momentum: 0.0,
515            gradient_clip: None,
516            custom_params: HashMap::new(),
517        }
518    }
519}
520
521/// Utility functions for plugin development
522/// Create a basic plugin info structure
523#[allow(dead_code)]
524pub fn create_plugin_info(name: &str, version: &str, author: &str) -> PluginInfo {
525    PluginInfo {
526        name: name.to_string(),
527        version: version.to_string(),
528        author: author.to_string(),
529        ..Default::default()
530    }
531}
532
533/// Create basic plugin capabilities
534#[allow(dead_code)]
535pub fn create_basic_capabilities() -> PluginCapabilities {
536    PluginCapabilities {
537        state_serialization: true,
538        thread_safe: true,
539        ..Default::default()
540    }
541}
542
543/// Validate plugin configuration against schema
544#[allow(dead_code)]
545pub fn validate_config_against_schema(
546    config: &OptimizerConfig,
547    schema: &ConfigSchema,
548) -> Result<()> {
549    // Check required fields
550    for required_field in &schema.required_fields {
551        match required_field.as_str() {
552            "learning_rate" => {
553                if config.learning_rate <= 0.0 {
554                    return Err(OptimError::InvalidConfig(
555                        "Learning rate must be positive".to_string(),
556                    ));
557                }
558            }
559            "weight_decay" => {
560                if config.weight_decay < 0.0 {
561                    return Err(OptimError::InvalidConfig(
562                        "Weight decay must be non-negative".to_string(),
563                    ));
564                }
565            }
566            _ => {
567                if !config.custom_params.contains_key(required_field) {
568                    return Err(OptimError::InvalidConfig(format!(
569                        "Required field '{}' is missing",
570                        required_field
571                    )));
572                }
573            }
574        }
575    }
576
577    // Validate field constraints
578    for (field_name, field_schema) in &schema.fields {
579        let value = match field_name.as_str() {
580            "learning_rate" => Some(ConfigValue::Float(config.learning_rate)),
581            "weight_decay" => Some(ConfigValue::Float(config.weight_decay)),
582            "momentum" => Some(ConfigValue::Float(config.momentum)),
583            _ => config.custom_params.get(field_name).cloned(),
584        };
585
586        if let Some(value) = value {
587            validate_field_value(&value, field_schema)?;
588        } else if field_schema.required {
589            return Err(OptimError::InvalidConfig(format!(
590                "Required field '{}' is missing",
591                field_name
592            )));
593        }
594    }
595
596    Ok(())
597}
598
599/// Validate individual field value against schema
600#[allow(dead_code)]
601fn validate_field_value(value: &ConfigValue, schema: &FieldSchema) -> Result<()> {
602    for constraint in &schema.constraints {
603        match (value, constraint) {
604            (ConfigValue::Float(v), ValidationConstraint::Min(min)) => {
605                if v < min {
606                    return Err(OptimError::InvalidConfig(format!(
607                        "Value {} is below minimum {}",
608                        v, min
609                    )));
610                }
611            }
612            (ConfigValue::Float(v), ValidationConstraint::Max(max)) => {
613                if v > max {
614                    return Err(OptimError::InvalidConfig(format!(
615                        "Value {} is above maximum {}",
616                        v, max
617                    )));
618                }
619            }
620            (ConfigValue::Float(v), ValidationConstraint::Positive) => {
621                if *v <= 0.0 {
622                    return Err(OptimError::InvalidConfig(
623                        "Value must be positive".to_string(),
624                    ));
625                }
626            }
627            (ConfigValue::Float(v), ValidationConstraint::NonNegative) => {
628                if *v < 0.0 {
629                    return Err(OptimError::InvalidConfig(
630                        "Value must be non-negative".to_string(),
631                    ));
632                }
633            }
634            (ConfigValue::Float(v), ValidationConstraint::Range(min, max)) => {
635                if v < min || v > max {
636                    return Err(OptimError::InvalidConfig(format!(
637                        "Value {} is outside range [{}, {}]",
638                        v, min, max
639                    )));
640                }
641            }
642            _ => {} // Other constraint types can be added as needed
643        }
644    }
645    Ok(())
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema that requires `learning_rate` and constrains it to be positive.
    fn learning_rate_schema() -> ConfigSchema {
        let mut fields = HashMap::new();
        fields.insert(
            "learning_rate".to_string(),
            FieldSchema {
                field_type: FieldType::Float {
                    min: Some(0.0),
                    max: None,
                },
                description: "Learning rate".to_string(),
                default_value: Some(ConfigValue::Float(0.001)),
                constraints: vec![ValidationConstraint::Positive],
                required: true,
            },
        );

        ConfigSchema {
            fields,
            required_fields: vec!["learning_rate".to_string()],
            version: "1.0".to_string(),
        }
    }

    #[test]
    fn test_plugin_info_default() {
        let info = PluginInfo::default();
        assert_eq!(info.name, "Unknown");
        assert_eq!(info.version, "0.1.0");
    }

    #[test]
    fn test_plugin_capabilities_default() {
        let caps = PluginCapabilities::default();
        assert!(!caps.sparse_gradients);
        assert!(!caps.gpu_support);
    }

    #[test]
    fn test_create_plugin_info_keeps_other_defaults() {
        let info = create_plugin_info("adam", "2.0.0", "someone");
        assert_eq!(info.name, "adam");
        assert_eq!(info.version, "2.0.0");
        assert_eq!(info.author, "someone");
        assert_eq!(info.license, "MIT");
        assert_eq!(info.category, PluginCategory::FirstOrder);
    }

    #[test]
    fn test_create_basic_capabilities() {
        let caps = create_basic_capabilities();
        assert!(caps.state_serialization);
        assert!(caps.thread_safe);
        assert!(!caps.gpu_support);
    }

    #[test]
    fn test_config_validation() {
        let schema = learning_rate_schema();

        let valid = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        assert!(validate_config_against_schema(&valid, &schema).is_ok());

        let invalid = OptimizerConfig {
            learning_rate: -0.001,
            ..Default::default()
        };
        assert!(validate_config_against_schema(&invalid, &schema).is_err());
    }

    #[test]
    fn test_required_weight_decay_must_be_non_negative() {
        let schema = ConfigSchema {
            fields: HashMap::new(),
            required_fields: vec!["weight_decay".to_string()],
            version: "1.0".to_string(),
        };

        assert!(validate_config_against_schema(&OptimizerConfig::default(), &schema).is_ok());

        let negative = OptimizerConfig {
            weight_decay: -0.1,
            ..Default::default()
        };
        assert!(validate_config_against_schema(&negative, &schema).is_err());
    }

    #[test]
    fn test_required_custom_field() {
        let schema = ConfigSchema {
            fields: HashMap::new(),
            required_fields: vec!["beta1".to_string()],
            version: "1.0".to_string(),
        };

        let mut config = OptimizerConfig::default();
        assert!(validate_config_against_schema(&config, &schema).is_err());

        config
            .custom_params
            .insert("beta1".to_string(), ConfigValue::Float(0.9));
        assert!(validate_config_against_schema(&config, &schema).is_ok());
    }
}