1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Float;
9use serde::{Deserialize, Serialize};
10use std::any::Any;
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::time::Duration;
14
/// Core interface implemented by every optimizer plugin.
///
/// The trait is generic over the float element type `A`, and implementors must
/// also be `Debug + Send + Sync`.
pub trait OptimizerPlugin<A: Float>: Debug + Send + Sync {
    /// Performs one optimization step from `params` and `gradients`.
    ///
    /// Returns an `Array1<A>` on success (presumably the updated parameters —
    /// confirm against implementors).
    fn step(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>>;

    /// Returns the plugin name.
    fn name(&self) -> &str;

    /// Returns the plugin version string.
    fn version(&self) -> &str;

    /// Returns the plugin's descriptive metadata.
    fn plugin_info(&self) -> PluginInfo;

    /// Returns the capability flags reported by the plugin.
    fn capabilities(&self) -> PluginCapabilities;

    /// Initializes the plugin for the given parameter shape.
    fn initialize(&mut self, paramshape: &[usize]) -> Result<()>;

    /// Resets the plugin; what exactly is cleared is up to the implementor.
    fn reset(&mut self) -> Result<()>;

    /// Returns the current configuration by value.
    fn get_config(&self) -> OptimizerConfig;

    /// Replaces the configuration; may fail with an error.
    fn set_config(&mut self, config: OptimizerConfig) -> Result<()>;

    /// Returns a snapshot of the optimizer state; may fail with an error.
    fn get_state(&self) -> Result<OptimizerState>;

    /// Installs a previously captured optimizer state; may fail with an error.
    fn set_state(&mut self, state: OptimizerState) -> Result<()>;

    /// Clones the plugin into a boxed trait object.
    ///
    /// Exists because `Clone` itself cannot be called through `dyn OptimizerPlugin`.
    fn clone_plugin(&self) -> Box<dyn OptimizerPlugin<A>>;

    /// Reports memory usage. The default returns `MemoryUsage::default()`.
    fn memory_usage(&self) -> MemoryUsage {
        MemoryUsage::default()
    }

    /// Reports performance metrics. The default returns
    /// `PerformanceMetrics::default()`.
    fn performance_metrics(&self) -> PerformanceMetrics {
        PerformanceMetrics::default()
    }
}
63
/// Optional extension of [`OptimizerPlugin`] with additional operations.
///
/// Every method is required; there are no default implementations.
pub trait ExtendedOptimizerPlugin<A: Float>: OptimizerPlugin<A> {
    /// Two-dimensional counterpart of `step`, taking and returning `Array2<A>`.
    ///
    /// NOTE(review): the axis layout (e.g. one parameter set per row) is not
    /// specified here — confirm with implementors.
    fn batch_step(&mut self, params: &Array2<A>, gradients: &Array2<A>) -> Result<Array2<A>>;

    /// Computes a learning rate of type `A` from the given gradients.
    fn adaptive_learning_rate(&self, gradients: &Array1<A>) -> A;

    /// Transforms gradients and returns the result as a new array.
    fn preprocess_gradients(&self, gradients: &Array1<A>) -> Result<Array1<A>>;

    /// Transforms parameters and returns the result as a new array.
    fn postprocess_parameters(&self, params: &Array1<A>) -> Result<Array1<A>>;

    /// Returns a list of recorded arrays (presumably the parameter history —
    /// confirm against implementors).
    fn get_trajectory(&self) -> Vec<Array1<A>>;

    /// Returns the current convergence metrics.
    fn convergence_metrics(&self) -> ConvergenceMetrics;
}
84
/// Descriptive metadata about a plugin. Serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginInfo {
    /// Plugin name.
    pub name: String,
    /// Plugin version, stored as a free-form string.
    pub version: String,
    /// Plugin author.
    pub author: String,
    /// Human-readable description.
    pub description: String,
    /// Optional homepage string (presumably a URL — not validated here).
    pub homepage: Option<String>,
    /// License identifier, stored as a free-form string.
    pub license: String,
    /// Data types the plugin declares support for.
    pub supported_types: Vec<DataType>,
    /// Category the plugin belongs to.
    pub category: PluginCategory,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Minimum SDK version, stored as a free-form string.
    pub min_sdk_version: String,
    /// Declared dependencies of the plugin.
    pub dependencies: Vec<PluginDependency>,
}
111
/// Capability flags reported by a plugin.
///
/// `Default` is derived, so every flag starts out `false`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PluginCapabilities {
    /// Sparse-gradient support flag.
    pub sparse_gradients: bool,
    /// Parameter-group support flag.
    pub parameter_groups: bool,
    /// Momentum support flag.
    pub momentum: bool,
    /// Adaptive-learning-rate support flag.
    pub adaptive_learning_rate: bool,
    /// Weight-decay support flag.
    pub weight_decay: bool,
    /// Gradient-clipping support flag.
    pub gradient_clipping: bool,
    /// Batch-processing support flag.
    pub batch_processing: bool,
    /// State-serialization support flag.
    pub state_serialization: bool,
    /// Thread-safety flag (self-declared; nothing in this file enforces it).
    pub thread_safe: bool,
    /// Memory-efficiency flag.
    pub memory_efficient: bool,
    /// GPU support flag.
    pub gpu_support: bool,
    /// SIMD-optimization flag.
    pub simd_optimized: bool,
    /// Custom-loss-function support flag.
    pub custom_loss_functions: bool,
    /// Regularization support flag.
    pub regularization: bool,
}
144
/// Element data types a plugin can declare support for.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum DataType {
    /// `f32` elements.
    F32,
    /// `f64` elements.
    F64,
    /// `i32` elements.
    I32,
    /// `i64` elements.
    I64,
    /// Complex elements (presumably with 32-bit components — confirm).
    Complex32,
    /// Complex elements (presumably with 64-bit components — confirm).
    Complex64,
    /// Any other type, identified by a free-form name.
    Custom(String),
}
156
/// Broad classification of a plugin.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum PluginCategory {
    /// First-order optimizer category.
    FirstOrder,
    /// Second-order optimizer category.
    SecondOrder,
    /// Specialized optimizer category.
    Specialized,
    /// Meta-learning optimizer category.
    MetaLearning,
    /// Experimental optimizer category.
    Experimental,
    /// Utility plugin category.
    Utility,
}
173
/// A single dependency declared by a plugin.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginDependency {
    /// Dependency name.
    pub name: String,
    /// Dependency version, stored as a free-form string (requirement syntax
    /// is not defined in this file).
    pub version: String,
    /// Whether the dependency is optional.
    pub optional: bool,
    /// Kind of dependency.
    pub dependency_type: DependencyType,
}
186
/// Kind of thing a [`PluginDependency`] refers to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DependencyType {
    /// Another plugin.
    Plugin,
    /// A system library.
    SystemLibrary,
    /// A Rust crate.
    Crate,
    /// A runtime dependency.
    Runtime,
}
199
/// Optimizer configuration: a few built-in hyperparameters plus free-form
/// custom parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig {
    /// Learning rate.
    pub learning_rate: f64,
    /// Weight decay coefficient.
    pub weight_decay: f64,
    /// Momentum coefficient.
    pub momentum: f64,
    /// Optional gradient clipping value; `None` when unset.
    /// NOTE(review): whether this clips by norm or by value is not shown here.
    pub gradient_clip: Option<f64>,
    /// Additional parameters keyed by name.
    pub custom_params: HashMap<String, ConfigValue>,
}
214
/// Dynamically typed configuration value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConfigValue {
    /// Floating-point value.
    Float(f64),
    /// Signed integer value.
    Integer(i64),
    /// Boolean value.
    Boolean(bool),
    /// String value.
    String(String),
    /// Array of floating-point values.
    Array(Vec<f64>),
}
224
/// Serializable snapshot of an optimizer's internal state.
///
/// `Default` is derived: empty maps and a step count of zero.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OptimizerState {
    /// Named state vectors stored as `f64` values.
    pub state_vectors: HashMap<String, Vec<f64>>,
    /// Step counter.
    pub step_count: usize,
    /// Additional named state entries.
    pub custom_state: HashMap<String, StateValue>,
}
235
/// Dynamically typed optimizer state value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StateValue {
    /// Floating-point value.
    Float(f64),
    /// Signed integer value.
    Integer(i64),
    /// Boolean value.
    Boolean(bool),
    /// String value.
    String(String),
    /// Array of floating-point values.
    Array(Vec<f64>),
    /// Nested vectors of floating-point values (presumably row-major, with
    /// equal row lengths — neither is enforced by the type).
    Matrix(Vec<Vec<f64>>),
}
246
/// Memory usage report. `Default` is derived, so all fields start at zero.
#[derive(Debug, Clone, Default)]
pub struct MemoryUsage {
    /// Current usage (presumably in bytes — TODO confirm).
    pub current_usage: usize,
    /// Peak usage (presumably in bytes — TODO confirm).
    pub peak_usage: usize,
    /// Efficiency score; scale and range are not defined in this file.
    pub efficiency_score: f64,
}
257
/// Performance report. `Default` is derived, so all fields start at zero.
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
    /// Average time per step; units are not defined in this file — TODO confirm.
    pub avg_step_time: f64,
    /// Total number of steps.
    pub total_steps: usize,
    /// Throughput; units are not defined in this file — TODO confirm.
    pub throughput: f64,
    /// CPU utilization; scale (fraction vs. percent) is not defined here.
    pub cpu_utilization: f64,
}
270
/// Convergence report. `Default` is derived, so all fields start at zero.
#[derive(Debug, Clone, Default)]
pub struct ConvergenceMetrics {
    /// Norm of the gradient (which norm is not specified here).
    pub gradient_norm: f64,
    /// Norm of the parameter change (which norm is not specified here).
    pub parameter_change_norm: f64,
    /// Rate of loss improvement; definition is left to the producer.
    pub loss_improvement_rate: f64,
    /// Aggregate convergence score; scale and range are not defined here.
    pub convergence_score: f64,
}
283
/// Outcome of validating a plugin.
#[derive(Debug, Clone)]
pub struct PluginValidationResult {
    /// Whether validation passed.
    pub is_valid: bool,
    /// Error messages collected during validation.
    pub errors: Vec<String>,
    /// Warning messages collected during validation.
    pub warnings: Vec<String>,
    /// Benchmark results, when available.
    pub benchmark_results: Option<BenchmarkResults>,
}
296
/// Raw benchmark measurements.
///
/// NOTE(review): whether the vectors are index-aligned (one entry per run) is
/// not established in this file.
#[derive(Debug, Clone)]
pub struct BenchmarkResults {
    /// Measured execution times.
    pub execution_times: Vec<Duration>,
    /// Measured memory usage values (presumably bytes — TODO confirm).
    pub memory_usage: Vec<usize>,
    /// Accuracy scores.
    pub accuracy_scores: Vec<f64>,
    /// Convergence rates.
    pub convergence_rates: Vec<f64>,
}
309
/// Factory that builds boxed [`OptimizerPlugin`] instances from a configuration.
///
/// Implementors must be `Debug + Send + Sync`.
pub trait OptimizerPluginFactory<A: Float>: Debug + Send + Sync {
    /// Creates a new optimizer from `config`; may fail with an error.
    fn create_optimizer(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<A>>>;

    /// Returns metadata describing this factory.
    fn factory_info(&self) -> PluginInfo;

    /// Checks `config`, returning an error when it is not acceptable.
    fn validate_config(&self, config: &OptimizerConfig) -> Result<()>;

    /// Returns the factory's default configuration.
    fn default_config(&self) -> OptimizerConfig;

    /// Returns the schema describing the configuration fields.
    fn config_schema(&self) -> ConfigSchema;
}
327
/// Schema describing the fields of an [`OptimizerConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigSchema {
    /// Per-field schemas keyed by field name.
    pub fields: HashMap<String, FieldSchema>,
    /// Names of fields that must be present.
    pub required_fields: Vec<String>,
    /// Schema version, stored as a free-form string.
    pub version: String,
}
338
/// Schema for a single configuration field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldSchema {
    /// Declared type of the field.
    pub field_type: FieldType,
    /// Human-readable description.
    pub description: String,
    /// Optional default value.
    pub default_value: Option<ConfigValue>,
    /// Constraints checked against the field's value by `validate_field_value`.
    pub constraints: Vec<ValidationConstraint>,
    /// Whether the field must be present; `validate_config_against_schema`
    /// rejects a missing field when this is `true`.
    pub required: bool,
}
353
/// Declared type of a configuration field, with optional bounds.
///
/// NOTE(review): the validation functions in this file do not consult these
/// bounds; only `ValidationConstraint` entries are checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FieldType {
    /// Floating-point field with optional lower and upper bounds.
    Float {
        min: Option<f64>,
        max: Option<f64>,
    },
    /// Integer field with optional lower and upper bounds.
    Integer {
        min: Option<i64>,
        max: Option<i64>,
    },
    /// Boolean field.
    Boolean,
    /// String field with an optional maximum length.
    String {
        max_length: Option<usize>,
    },
    /// Array field with a boxed element type and an optional maximum length.
    Array {
        element_type: Box<FieldType>,
        max_length: Option<usize>,
    },
    /// Field restricted to one of the listed options.
    Choice {
        options: Vec<String>,
    },
}
377
/// Constraint attached to a [`FieldSchema`].
///
/// The numeric variants are checked by `validate_field_value`; `Pattern` and
/// `Custom` are not interpreted by that function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ValidationConstraint {
    /// Inclusive lower bound: values below it are rejected.
    Min(f64),
    /// Inclusive upper bound: values above it are rejected.
    Max(f64),
    /// Value must be strictly greater than zero.
    Positive,
    /// Value must be greater than or equal to zero.
    NonNegative,
    /// Inclusive range `(min, max)`: values outside it are rejected.
    Range(f64, f64),
    /// Pattern string (presumably a regular expression — not evaluated here).
    Pattern(String),
    /// Custom constraint identified by a free-form string (not evaluated here).
    Custom(String),
}
396
/// Lifecycle hooks for a plugin.
///
/// Every hook has a default implementation that does nothing and returns
/// `Ok(())`, so implementors override only the hooks they need.
pub trait PluginLifecycle {
    /// Hook for plugin load. Default: no-op returning `Ok(())`.
    fn on_load(&mut self) -> Result<()> {
        Ok(())
    }

    /// Hook for plugin unload. Default: no-op returning `Ok(())`.
    fn on_unload(&mut self) -> Result<()> {
        Ok(())
    }

    /// Hook for plugin enable. Default: no-op returning `Ok(())`.
    fn on_enable(&mut self) -> Result<()> {
        Ok(())
    }

    /// Hook for plugin disable. Default: no-op returning `Ok(())`.
    fn on_disable(&mut self) -> Result<()> {
        Ok(())
    }

    /// Hook for maintenance work. Default: no-op returning `Ok(())`.
    fn on_maintenance(&mut self) -> Result<()> {
        Ok(())
    }
}
424
425pub trait PluginEventHandler {
427 fn on_step(&mut self, _step: usize, _params: &Array1<f64>, gradients: &Array1<f64>) {}
429
430 fn on_convergence(&mut self, _finalparams: &Array1<f64>) {}
432
433 fn on_error(&mut self, error: &OptimError) {}
435
436 fn on_custom_event(&mut self, _event_name: &str, data: &dyn Any) {}
438}
439
/// Optional descriptive metadata a plugin can expose.
///
/// Every method has a default implementation returning an empty value.
pub trait PluginMetadata {
    /// Returns documentation text. Default: empty string.
    fn documentation(&self) -> String {
        String::new()
    }

    /// Returns usage examples. Default: empty list.
    fn examples(&self) -> Vec<PluginExample> {
        Vec::new()
    }

    /// Returns changelog text. Default: empty string.
    fn changelog(&self) -> String {
        String::new()
    }

    /// Returns compatibility information. Default:
    /// `CompatibilityInfo::default()`.
    fn compatibility(&self) -> CompatibilityInfo {
        CompatibilityInfo::default()
    }
}
462
/// A usage example attached to a plugin.
#[derive(Debug, Clone)]
pub struct PluginExample {
    /// Example title.
    pub title: String,
    /// Example description.
    pub description: String,
    /// Example source code, stored as text.
    pub code: String,
    /// Expected output, stored as text.
    pub expected_output: String,
}
475
/// Compatibility notes for a plugin. `Default` is derived, so all lists start
/// out empty.
#[derive(Debug, Clone, Default)]
pub struct CompatibilityInfo {
    /// Rust versions, stored as free-form strings.
    pub rust_versions: Vec<String>,
    /// Platforms, stored as free-form strings.
    pub platforms: Vec<String>,
    /// Known issues.
    pub known_issues: Vec<String>,
    /// Breaking changes.
    pub breaking_changes: Vec<String>,
}
488
489impl Default for PluginInfo {
492 fn default() -> Self {
493 Self {
494 name: "Unknown".to_string(),
495 version: "0.1.0".to_string(),
496 author: "Unknown".to_string(),
497 description: "No description provided".to_string(),
498 homepage: None,
499 license: "MIT".to_string(),
500 supported_types: vec![DataType::F32, DataType::F64],
501 category: PluginCategory::FirstOrder,
502 tags: Vec::new(),
503 min_sdk_version: "0.1.0".to_string(),
504 dependencies: Vec::new(),
505 }
506 }
507}
508
509impl Default for OptimizerConfig {
510 fn default() -> Self {
511 Self {
512 learning_rate: 0.001,
513 weight_decay: 0.0,
514 momentum: 0.0,
515 gradient_clip: None,
516 custom_params: HashMap::new(),
517 }
518 }
519}
520
521#[allow(dead_code)]
524pub fn create_plugin_info(name: &str, version: &str, author: &str) -> PluginInfo {
525 PluginInfo {
526 name: name.to_string(),
527 version: version.to_string(),
528 author: author.to_string(),
529 ..Default::default()
530 }
531}
532
533#[allow(dead_code)]
535pub fn create_basic_capabilities() -> PluginCapabilities {
536 PluginCapabilities {
537 state_serialization: true,
538 thread_safe: true,
539 ..Default::default()
540 }
541}
542
543#[allow(dead_code)]
545pub fn validate_config_against_schema(
546 config: &OptimizerConfig,
547 schema: &ConfigSchema,
548) -> Result<()> {
549 for required_field in &schema.required_fields {
551 match required_field.as_str() {
552 "learning_rate" => {
553 if config.learning_rate <= 0.0 {
554 return Err(OptimError::InvalidConfig(
555 "Learning rate must be positive".to_string(),
556 ));
557 }
558 }
559 "weight_decay" => {
560 if config.weight_decay < 0.0 {
561 return Err(OptimError::InvalidConfig(
562 "Weight decay must be non-negative".to_string(),
563 ));
564 }
565 }
566 _ => {
567 if !config.custom_params.contains_key(required_field) {
568 return Err(OptimError::InvalidConfig(format!(
569 "Required field '{}' is missing",
570 required_field
571 )));
572 }
573 }
574 }
575 }
576
577 for (field_name, field_schema) in &schema.fields {
579 let value = match field_name.as_str() {
580 "learning_rate" => Some(ConfigValue::Float(config.learning_rate)),
581 "weight_decay" => Some(ConfigValue::Float(config.weight_decay)),
582 "momentum" => Some(ConfigValue::Float(config.momentum)),
583 _ => config.custom_params.get(field_name).cloned(),
584 };
585
586 if let Some(value) = value {
587 validate_field_value(&value, field_schema)?;
588 } else if field_schema.required {
589 return Err(OptimError::InvalidConfig(format!(
590 "Required field '{}' is missing",
591 field_name
592 )));
593 }
594 }
595
596 Ok(())
597}
598
599#[allow(dead_code)]
601fn validate_field_value(value: &ConfigValue, schema: &FieldSchema) -> Result<()> {
602 for constraint in &schema.constraints {
603 match (value, constraint) {
604 (ConfigValue::Float(v), ValidationConstraint::Min(min)) => {
605 if v < min {
606 return Err(OptimError::InvalidConfig(format!(
607 "Value {} is below minimum {}",
608 v, min
609 )));
610 }
611 }
612 (ConfigValue::Float(v), ValidationConstraint::Max(max)) => {
613 if v > max {
614 return Err(OptimError::InvalidConfig(format!(
615 "Value {} is above maximum {}",
616 v, max
617 )));
618 }
619 }
620 (ConfigValue::Float(v), ValidationConstraint::Positive) => {
621 if *v <= 0.0 {
622 return Err(OptimError::InvalidConfig(
623 "Value must be positive".to_string(),
624 ));
625 }
626 }
627 (ConfigValue::Float(v), ValidationConstraint::NonNegative) => {
628 if *v < 0.0 {
629 return Err(OptimError::InvalidConfig(
630 "Value must be non-negative".to_string(),
631 ));
632 }
633 }
634 (ConfigValue::Float(v), ValidationConstraint::Range(min, max)) => {
635 if v < min || v > max {
636 return Err(OptimError::InvalidConfig(format!(
637 "Value {} is outside range [{}, {}]",
638 v, min, max
639 )));
640 }
641 }
642 _ => {} }
644 }
645 Ok(())
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    /// The default plugin info carries the placeholder name and version.
    #[test]
    fn test_plugin_info_default() {
        let info = PluginInfo::default();
        assert_eq!(info.name, "Unknown");
        assert_eq!(info.version, "0.1.0");
    }

    /// Derived `Default` leaves capability flags switched off.
    #[test]
    fn test_plugin_capabilities_default() {
        let caps = PluginCapabilities::default();
        assert!(!caps.sparse_gradients);
        assert!(!caps.gpu_support);
    }

    /// A positive learning rate passes schema validation; a negative one fails.
    #[test]
    fn test_config_validation() {
        let mut schema = ConfigSchema {
            fields: HashMap::new(),
            required_fields: vec!["learning_rate".to_string()],
            version: "1.0".to_string(),
        };

        schema.fields.insert(
            "learning_rate".to_string(),
            FieldSchema {
                field_type: FieldType::Float {
                    min: Some(0.0),
                    max: None,
                },
                description: "Learning rate".to_string(),
                default_value: Some(ConfigValue::Float(0.001)),
                constraints: vec![ValidationConstraint::Positive],
                required: true,
            },
        );

        // Neither config is mutated after construction, so the bindings are
        // immutable (avoids `unused_mut` warnings).
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };

        assert!(validate_config_against_schema(&config, &schema).is_ok());

        let config = OptimizerConfig {
            learning_rate: -0.001,
            ..Default::default()
        };
        assert!(validate_config_against_schema(&config, &schema).is_err());
    }
}