use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::time::Duration;
/// Core interface every optimizer plugin implements, generic over the float
/// element type `A`.
///
/// Implementors must also be `Debug + Send + Sync`. All methods are required
/// except `memory_usage` and `performance_metrics`, which have default bodies
/// returning `Default::default()` values.
pub trait OptimizerPlugin<A: Float>: Debug + Send + Sync {
/// Runs one optimizer step from borrowed `params` and `gradients`; on success
/// returns an owned `Array1<A>` (presumably the updated parameters — confirm
/// against implementors).
fn step(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>>;
/// Returns the plugin's name as a borrowed string.
fn name(&self) -> &str;
/// Returns the plugin's version as a borrowed string.
fn version(&self) -> &str;
/// Returns descriptive metadata for this plugin as a `PluginInfo`.
fn plugin_info(&self) -> PluginInfo;
/// Returns the feature flags this plugin declares as `PluginCapabilities`.
fn capabilities(&self) -> PluginCapabilities;
/// Prepares the plugin for parameters of shape `paramshape`; may fail.
fn initialize(&mut self, paramshape: &[usize]) -> Result<()>;
/// Resets the plugin; exactly which state is cleared is up to the implementor.
fn reset(&mut self) -> Result<()>;
/// Returns the current configuration by value.
fn get_config(&self) -> OptimizerConfig;
/// Replaces the configuration; implementors may reject it with an error.
fn set_config(&mut self, config: OptimizerConfig) -> Result<()>;
/// Exports the optimizer state as an `OptimizerState`; may fail.
fn get_state(&self) -> Result<OptimizerState>;
/// Installs the given `OptimizerState`; may fail.
fn set_state(&mut self, state: OptimizerState) -> Result<()>;
/// Returns a boxed copy of this plugin as a trait object.
fn clone_plugin(&self) -> Box<dyn OptimizerPlugin<A>>;
/// Reports memory usage; the default returns `MemoryUsage::default()`.
fn memory_usage(&self) -> MemoryUsage {
MemoryUsage::default()
}
/// Reports performance metrics; the default returns
/// `PerformanceMetrics::default()`.
fn performance_metrics(&self) -> PerformanceMetrics {
PerformanceMetrics::default()
}
}
/// Optional extension of `OptimizerPlugin` adding batch, pre/post-processing
/// and introspection methods. All methods are required (no defaults).
pub trait ExtendedOptimizerPlugin<A: Float>: OptimizerPlugin<A> {
/// Two-dimensional counterpart of `step`, taking and returning `Array2<A>`.
/// NOTE(review): which axis indexes the batch is not specified here — confirm
/// against implementors.
fn batch_step(&mut self, params: &Array2<A>, gradients: &Array2<A>) -> Result<Array2<A>>;
/// Returns a single learning-rate value of type `A` derived from `gradients`.
fn adaptive_learning_rate(&self, gradients: &Array1<A>) -> A;
/// Returns a transformed copy of `gradients`; may fail.
fn preprocess_gradients(&self, gradients: &Array1<A>) -> Result<Array1<A>>;
/// Returns a transformed copy of `params`; may fail.
fn postprocess_parameters(&self, params: &Array1<A>) -> Result<Array1<A>>;
/// Returns a list of owned parameter arrays (presumably the history of
/// parameter values across steps — confirm against implementors).
fn get_trajectory(&self) -> Vec<Array1<A>>;
/// Returns the plugin's current `ConvergenceMetrics`.
fn convergence_metrics(&self) -> ConvergenceMetrics;
}
/// Descriptive metadata for a plugin, returned by
/// `OptimizerPlugin::plugin_info` and `OptimizerPluginFactory::factory_info`.
///
/// Serializable with serde. The manual `Default` impl in this file supplies
/// placeholder values, noted per field below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginInfo {
/// Plugin name; defaults to `"Unknown"`.
pub name: String,
/// Plugin version string; defaults to `"0.1.0"`. No format is enforced here.
pub version: String,
/// Plugin author; defaults to `"Unknown"`.
pub author: String,
/// Free-form description; defaults to `"No description provided"`.
pub description: String,
/// Optional homepage; defaults to `None`.
pub homepage: Option<String>,
/// License identifier; defaults to `"MIT"`.
pub license: String,
/// Data types the plugin declares support for; defaults to `[F32, F64]`.
pub supported_types: Vec<DataType>,
/// Category the plugin is filed under; defaults to `PluginCategory::FirstOrder`.
pub category: PluginCategory,
/// Free-form tags; empty by default.
pub tags: Vec<String>,
/// Minimum SDK version string; defaults to `"0.1.0"`.
pub min_sdk_version: String,
/// Declared dependencies; empty by default.
pub dependencies: Vec<PluginDependency>,
}
/// Feature flags a plugin declares through `OptimizerPlugin::capabilities`.
///
/// The derived `Default` sets every flag to `false`. Each flag is a
/// self-declared claim; nothing in the visible code verifies or acts on it.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PluginCapabilities {
pub sparse_gradients: bool,
pub parameter_groups: bool,
pub momentum: bool,
pub adaptive_learning_rate: bool,
pub weight_decay: bool,
pub gradient_clipping: bool,
pub batch_processing: bool,
/// Set to `true` by `create_basic_capabilities`.
pub state_serialization: bool,
/// Set to `true` by `create_basic_capabilities`.
pub thread_safe: bool,
pub memory_efficient: bool,
pub gpu_support: bool,
pub simd_optimized: bool,
pub custom_loss_functions: bool,
pub regularization: bool,
}
/// Data types a plugin can list in `PluginInfo::supported_types`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum DataType {
F32,
F64,
I32,
I64,
Complex32,
Complex64,
/// Any other type, identified by a free-form string.
Custom(String),
}
/// Category label stored in `PluginInfo::category`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum PluginCategory {
/// Used by `PluginInfo::default()`.
FirstOrder,
SecondOrder,
Specialized,
MetaLearning,
Experimental,
Utility,
}
/// One dependency entry listed in `PluginInfo::dependencies`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginDependency {
/// Name of the dependency.
pub name: String,
/// Version string; whether this is an exact version or a requirement range
/// is not specified by this type.
pub version: String,
/// Whether the dependency is optional.
pub optional: bool,
/// What kind of artifact the dependency is.
pub dependency_type: DependencyType,
}
/// Kind of artifact a `PluginDependency` refers to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DependencyType {
Plugin,
SystemLibrary,
Crate,
Runtime,
}
/// Optimizer configuration exchanged through `OptimizerPlugin::get_config` /
/// `set_config` and consumed by `OptimizerPluginFactory::create_optimizer`.
///
/// The manual `Default` impl in this file supplies the defaults noted below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig {
/// Learning rate; defaults to `0.001`.
pub learning_rate: f64,
/// Weight decay; defaults to `0.0`.
pub weight_decay: f64,
/// Momentum; defaults to `0.0`.
pub momentum: f64,
/// Optional gradient-clipping value; defaults to `None`.
/// NOTE(review): whether this clips by norm or by value is not specified here.
pub gradient_clip: Option<f64>,
/// Additional named parameters; empty by default.
/// `validate_config_against_schema` looks schema fields up here when the
/// name is not one of the built-in fields it recognizes.
pub custom_params: HashMap<String, ConfigValue>,
}
/// Dynamically typed value used for `OptimizerConfig::custom_params` entries
/// and for `FieldSchema::default_value`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConfigValue {
Float(f64),
Integer(i64),
Boolean(bool),
String(String),
/// A list of floats.
Array(Vec<f64>),
}
/// Serializable optimizer state exchanged through
/// `OptimizerPlugin::get_state` / `set_state`.
///
/// The derived `Default` yields empty maps and a `step_count` of zero.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OptimizerState {
/// Named float vectors; the keys are chosen by the plugin.
pub state_vectors: HashMap<String, Vec<f64>>,
/// Step counter.
pub step_count: usize,
/// Additional named state values; the keys are chosen by the plugin.
pub custom_state: HashMap<String, StateValue>,
}
/// Dynamically typed value stored in `OptimizerState::custom_state`.
///
/// Has the same variants as `ConfigValue` plus `Matrix`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StateValue {
Float(f64),
Integer(i64),
Boolean(bool),
String(String),
/// A list of floats.
Array(Vec<f64>),
/// A nested list of floats; the type does not enforce equal inner lengths.
Matrix(Vec<Vec<f64>>),
}
/// Memory statistics returned by `OptimizerPlugin::memory_usage`.
///
/// The derived `Default` is all zeros.
/// NOTE(review): units (presumably bytes) and the scale of `efficiency_score`
/// are not specified in this file — confirm with implementors.
#[derive(Debug, Clone, Default)]
pub struct MemoryUsage {
pub current_usage: usize,
pub peak_usage: usize,
pub efficiency_score: f64,
}
/// Performance statistics returned by `OptimizerPlugin::performance_metrics`.
///
/// The derived `Default` is all zeros.
/// NOTE(review): time and throughput units are not specified in this file.
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
pub avg_step_time: f64,
pub total_steps: usize,
pub throughput: f64,
pub cpu_utilization: f64,
}
/// Convergence statistics returned by
/// `ExtendedOptimizerPlugin::convergence_metrics`.
///
/// The derived `Default` is all zeros.
/// NOTE(review): which norm is used and the scale of `convergence_score` are
/// not specified in this file.
#[derive(Debug, Clone, Default)]
pub struct ConvergenceMetrics {
pub gradient_norm: f64,
pub parameter_change_norm: f64,
pub loss_improvement_rate: f64,
pub convergence_score: f64,
}
/// Outcome of validating a plugin. Nothing in the visible code constructs this
/// type; the producer presumably lives elsewhere.
#[derive(Debug, Clone)]
pub struct PluginValidationResult {
/// Overall verdict.
pub is_valid: bool,
/// Error messages collected during validation.
pub errors: Vec<String>,
/// Warning messages collected during validation.
pub warnings: Vec<String>,
/// Benchmark data, when available.
pub benchmark_results: Option<BenchmarkResults>,
}
/// Benchmark measurements optionally attached to a `PluginValidationResult`.
///
/// NOTE(review): presumably the four vectors are parallel (one entry per run),
/// but the type does not enforce equal lengths — confirm with the producer.
#[derive(Debug, Clone)]
pub struct BenchmarkResults {
pub execution_times: Vec<Duration>,
pub memory_usage: Vec<usize>,
pub accuracy_scores: Vec<f64>,
pub convergence_rates: Vec<f64>,
}
/// Factory that builds boxed `OptimizerPlugin<A>` instances from an
/// `OptimizerConfig`. Implementors must also be `Debug + Send + Sync`.
pub trait OptimizerPluginFactory<A: Float>: Debug + Send + Sync {
/// Builds a new plugin from `config`, or returns an error.
fn create_optimizer(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<A>>>;
/// Returns descriptive metadata as a `PluginInfo`.
fn factory_info(&self) -> PluginInfo;
/// Checks `config` without building a plugin; returns an error if rejected.
fn validate_config(&self, config: &OptimizerConfig) -> Result<()>;
/// Returns the configuration this factory uses by default.
fn default_config(&self) -> OptimizerConfig;
/// Returns the `ConfigSchema` describing the configuration fields.
fn config_schema(&self) -> ConfigSchema;
}
/// Schema describing the fields of an `OptimizerConfig`, consumed by
/// `validate_config_against_schema`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigSchema {
/// Per-field schemas keyed by field name.
pub fields: HashMap<String, FieldSchema>,
/// Names of fields that must be present. This is checked separately from
/// each `FieldSchema::required` flag.
pub required_fields: Vec<String>,
/// Schema version string; not interpreted by the visible code.
pub version: String,
}
/// Schema for a single configuration field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldSchema {
/// Declared type of the field. `validate_field_value` in this file does not
/// consult it; only `constraints` are checked there.
pub field_type: FieldType,
/// Free-form description.
pub description: String,
/// Optional default value; not applied by the validation code in this file.
pub default_value: Option<ConfigValue>,
/// Constraints checked by `validate_field_value`.
pub constraints: Vec<ValidationConstraint>,
/// When `true`, `validate_config_against_schema` rejects a config in which
/// the field has no value.
pub required: bool,
}
/// Declared type of a schema field, with optional type-specific limits.
///
/// The limits carried here are declarative only: the validation functions in
/// this file do not read them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FieldType {
/// Floating-point field with optional bounds.
Float {
min: Option<f64>,
max: Option<f64>,
},
/// Integer field with optional bounds.
Integer {
min: Option<i64>,
max: Option<i64>,
},
Boolean,
/// String field with an optional maximum length.
String {
max_length: Option<usize>,
},
/// Array field with a boxed element type and an optional maximum length.
Array {
element_type: Box<FieldType>,
max_length: Option<usize>,
},
/// Field restricted to one of the listed string options.
Choice {
options: Vec<String>,
},
}
/// Constraint attached to a `FieldSchema` and checked by
/// `validate_field_value`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ValidationConstraint {
/// Value must not be below the given minimum.
Min(f64),
/// Value must not be above the given maximum.
Max(f64),
/// Value must be strictly greater than zero.
Positive,
/// Value must be greater than or equal to zero.
NonNegative,
/// `(min, max)`; both ends are inclusive as checked by `validate_field_value`.
Range(f64, f64),
/// Not evaluated by `validate_field_value` in this file.
Pattern(String),
/// Not evaluated by `validate_field_value` in this file.
Custom(String),
}
/// Optional lifecycle hooks for a plugin.
///
/// Every hook has a default body that does nothing and returns `Ok(())`, so
/// implementors override only what they need. When each hook is invoked is
/// decided by the host and is not shown in this file.
pub trait PluginLifecycle {
/// Load hook; default is a successful no-op.
fn on_load(&mut self) -> Result<()> {
Ok(())
}
/// Unload hook; default is a successful no-op.
fn on_unload(&mut self) -> Result<()> {
Ok(())
}
/// Enable hook; default is a successful no-op.
fn on_enable(&mut self) -> Result<()> {
Ok(())
}
/// Disable hook; default is a successful no-op.
fn on_disable(&mut self) -> Result<()> {
Ok(())
}
/// Maintenance hook; default is a successful no-op.
fn on_maintenance(&mut self) -> Result<()> {
Ok(())
}
}
pub trait PluginEventHandler {
fn on_step(&mut self, _step: usize, _params: &Array1<f64>, gradients: &Array1<f64>) {}
fn on_convergence(&mut self, _finalparams: &Array1<f64>) {}
fn on_error(&mut self, error: &OptimError) {}
fn on_custom_event(&mut self, _event_name: &str, data: &dyn Any) {}
}
/// Optional human-facing metadata for a plugin.
///
/// Every method has a default that returns an empty value.
pub trait PluginMetadata {
/// Documentation text; defaults to an empty string.
fn documentation(&self) -> String {
String::new()
}
/// Usage examples; defaults to an empty list.
fn examples(&self) -> Vec<PluginExample> {
Vec::new()
}
/// Changelog text; defaults to an empty string.
fn changelog(&self) -> String {
String::new()
}
/// Compatibility information; defaults to `CompatibilityInfo::default()`.
fn compatibility(&self) -> CompatibilityInfo {
CompatibilityInfo::default()
}
}
/// One usage example returned by `PluginMetadata::examples`.
#[derive(Debug, Clone)]
pub struct PluginExample {
/// Title of the example.
pub title: String,
/// Free-form description.
pub description: String,
/// Example source code, stored as text.
pub code: String,
/// Expected output, stored as text.
pub expected_output: String,
}
/// Compatibility notes returned by `PluginMetadata::compatibility`.
///
/// The derived `Default` leaves every list empty. All entries are free-form
/// strings; no format is enforced by this type.
#[derive(Debug, Clone, Default)]
pub struct CompatibilityInfo {
pub rust_versions: Vec<String>,
pub platforms: Vec<String>,
pub known_issues: Vec<String>,
pub breaking_changes: Vec<String>,
}
impl Default for PluginInfo {
fn default() -> Self {
Self {
name: "Unknown".to_string(),
version: "0.1.0".to_string(),
author: "Unknown".to_string(),
description: "No description provided".to_string(),
homepage: None,
license: "MIT".to_string(),
supported_types: vec![DataType::F32, DataType::F64],
category: PluginCategory::FirstOrder,
tags: Vec::new(),
min_sdk_version: "0.1.0".to_string(),
dependencies: Vec::new(),
}
}
}
impl Default for OptimizerConfig {
fn default() -> Self {
Self {
learning_rate: 0.001,
weight_decay: 0.0,
momentum: 0.0,
gradient_clip: None,
custom_params: HashMap::new(),
}
}
}
#[allow(dead_code)]
pub fn create_plugin_info(name: &str, version: &str, author: &str) -> PluginInfo {
PluginInfo {
name: name.to_string(),
version: version.to_string(),
author: author.to_string(),
..Default::default()
}
}
#[allow(dead_code)]
pub fn create_basic_capabilities() -> PluginCapabilities {
PluginCapabilities {
state_serialization: true,
thread_safe: true,
..Default::default()
}
}
/// Validates `config` against `schema`.
///
/// Validation runs in two passes:
///
/// 1. Every name in `schema.required_fields` must be satisfied. The built-in
///    struct fields are recognized by name: `learning_rate` must be positive,
///    `weight_decay` must be non-negative, `momentum` is always present, and
///    `gradient_clip` is satisfied when the struct field is `Some`. Any other
///    name (and `gradient_clip` when the struct field is `None`) must exist
///    as a key in `config.custom_params`.
/// 2. Every entry in `schema.fields` is resolved to a value and passed to
///    `validate_field_value` together with its schema. The value comes from
///    the built-in struct field when the name matches one, otherwise from
///    `config.custom_params`. A field with no value is an error only when its
///    schema is marked `required`.
///
/// # Errors
///
/// Returns `OptimError::InvalidConfig` describing the first violation found.
/// Because `schema.fields` is a `HashMap`, which violation is reported first
/// in pass 2 is not deterministic when several fields are invalid.
#[allow(dead_code)]
pub fn validate_config_against_schema(
    config: &OptimizerConfig,
    schema: &ConfigSchema,
) -> Result<()> {
    for required_field in &schema.required_fields {
        match required_field.as_str() {
            "learning_rate" => {
                if config.learning_rate <= 0.0 {
                    return Err(OptimError::InvalidConfig(
                        "Learning rate must be positive".to_string(),
                    ));
                }
            }
            "weight_decay" => {
                if config.weight_decay < 0.0 {
                    return Err(OptimError::InvalidConfig(
                        "Weight decay must be non-negative".to_string(),
                    ));
                }
            }
            // `momentum` is a plain struct field, so it can never be missing.
            "momentum" => {}
            // A populated struct field satisfies the requirement; when it is
            // `None` we fall through to the custom-parameter lookup below.
            "gradient_clip" if config.gradient_clip.is_some() => {}
            _ => {
                if !config.custom_params.contains_key(required_field) {
                    return Err(OptimError::InvalidConfig(format!(
                        "Required field '{}' is missing",
                        required_field
                    )));
                }
            }
        }
    }

    for (field_name, field_schema) in &schema.fields {
        // Built-in struct fields are wrapped in a temporary `ConfigValue` so
        // they can share the validation path with custom parameters.
        let builtin = match field_name.as_str() {
            "learning_rate" => Some(ConfigValue::Float(config.learning_rate)),
            "weight_decay" => Some(ConfigValue::Float(config.weight_decay)),
            "momentum" => Some(ConfigValue::Float(config.momentum)),
            "gradient_clip" => config.gradient_clip.map(ConfigValue::Float),
            _ => None,
        };
        // Borrow custom parameters instead of cloning them; array values can
        // be arbitrarily large.
        let value = builtin
            .as_ref()
            .or_else(|| config.custom_params.get(field_name));

        match value {
            Some(value) => validate_field_value(value, field_schema)?,
            None if field_schema.required => {
                return Err(OptimError::InvalidConfig(format!(
                    "Required field '{}' is missing",
                    field_name
                )));
            }
            None => {}
        }
    }

    Ok(())
}
/// Checks a single configuration `value` against `schema.constraints`.
///
/// Numeric constraints (`Min`, `Max`, `Positive`, `NonNegative`, `Range`)
/// apply to both `ConfigValue::Float` and `ConfigValue::Integer`; integers
/// are compared as `f64`. Boolean, string and array values are accepted
/// unchanged, because no constraint handled here applies to them.
///
/// `Pattern` and `Custom` constraints are not evaluated, and
/// `schema.field_type` is not consulted.
///
/// NOTE: a NaN value compares false against every bound and therefore passes
/// all numeric constraints.
///
/// # Errors
///
/// Returns `OptimError::InvalidConfig` for the first violated constraint, in
/// the order the constraints are listed.
#[allow(dead_code)]
fn validate_field_value(value: &ConfigValue, schema: &FieldSchema) -> Result<()> {
    let number = match value {
        ConfigValue::Float(v) => *v,
        // Integers beyond 2^53 in magnitude lose precision in this
        // conversion; that is acceptable for bound checks against `f64`
        // limits.
        ConfigValue::Integer(i) => *i as f64,
        ConfigValue::Boolean(_) | ConfigValue::String(_) | ConfigValue::Array(_) => {
            return Ok(());
        }
    };

    for constraint in &schema.constraints {
        let violation = match constraint {
            ValidationConstraint::Min(min) if number < *min => {
                Some(format!("Value {} is below minimum {}", number, min))
            }
            ValidationConstraint::Max(max) if number > *max => {
                Some(format!("Value {} is above maximum {}", number, max))
            }
            ValidationConstraint::Positive if number <= 0.0 => {
                Some("Value must be positive".to_string())
            }
            ValidationConstraint::NonNegative if number < 0.0 => {
                Some("Value must be non-negative".to_string())
            }
            ValidationConstraint::Range(min, max) if number < *min || number > *max => Some(
                format!("Value {} is outside range [{}, {}]", number, min, max),
            ),
            // Satisfied numeric constraints, plus the constraint kinds this
            // function does not evaluate.
            ValidationConstraint::Min(_)
            | ValidationConstraint::Max(_)
            | ValidationConstraint::Positive
            | ValidationConstraint::NonNegative
            | ValidationConstraint::Range(_, _)
            | ValidationConstraint::Pattern(_)
            | ValidationConstraint::Custom(_) => None,
        };

        if let Some(message) = violation {
            return Err(OptimError::InvalidConfig(message));
        }
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `PluginInfo::default()` yields the placeholder name and version.
    #[test]
    fn test_plugin_info_default() {
        let info = PluginInfo::default();
        assert_eq!(info.name, "Unknown");
        assert_eq!(info.version, "0.1.0");
    }

    /// The derived default leaves capability flags disabled.
    #[test]
    fn test_plugin_capabilities_default() {
        let caps = PluginCapabilities::default();
        assert!(!caps.sparse_gradients);
        assert!(!caps.gpu_support);
    }

    /// A positive learning rate passes schema validation and a negative one
    /// is rejected.
    #[test]
    fn test_config_validation() {
        let mut schema = ConfigSchema {
            fields: HashMap::new(),
            required_fields: vec!["learning_rate".to_string()],
            version: "1.0".to_string(),
        };
        schema.fields.insert(
            "learning_rate".to_string(),
            FieldSchema {
                field_type: FieldType::Float {
                    min: Some(0.0),
                    max: None,
                },
                description: "Learning rate".to_string(),
                default_value: Some(ConfigValue::Float(0.001)),
                constraints: vec![ValidationConstraint::Positive],
                required: true,
            },
        );

        // Neither config is mutated, so the bindings are immutable and have
        // distinct names instead of shadowing each other.
        let valid_config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        assert!(validate_config_against_schema(&valid_config, &schema).is_ok());

        let invalid_config = OptimizerConfig {
            learning_rate: -0.001,
            ..Default::default()
        };
        assert!(validate_config_against_schema(&invalid_config, &schema).is_err());
    }
}