optirs_core/plugin/
sdk.rs

1// Plugin SDK utilities and helpers for optimizer development
2//
3// This module provides a comprehensive SDK for developing custom optimizer plugins,
4// including base classes, utilities, testing frameworks, and development tools.
5
6#[allow(dead_code)]
7use super::core::*;
8#[cfg(feature = "cross-platform-testing")]
9use crate::benchmarking::cross_platform_tester::{PerformanceBaseline, PlatformTarget};
10use crate::error::{OptimError, Result};
11use scirs2_core::ndarray::Array1;
12use scirs2_core::numeric::Float;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
16/// Base optimizer plugin implementation with common functionality
17pub struct BaseOptimizerPlugin<A: Float + std::fmt::Debug> {
18    /// Plugin information
19    info: PluginInfo,
20    /// Plugin capabilities
21    capabilities: PluginCapabilities,
22    /// Optimizer configuration
23    config: OptimizerConfig,
24    /// Internal state
25    state: BaseOptimizerState<A>,
26    /// Performance metrics
27    metrics: PerformanceMetrics,
28    /// Memory usage tracking
29    memory_usage: MemoryUsage,
30    /// Event handlers
31    event_handlers: Vec<Box<dyn PluginEventHandler>>,
32}
33
34impl<A: Float + std::fmt::Debug + Send + Sync> std::fmt::Debug for BaseOptimizerPlugin<A> {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("BaseOptimizerPlugin")
37            .field("info", &self.info)
38            .field("capabilities", &self.capabilities)
39            .field("config", &self.config)
40            .field("state", &self.state)
41            .field("metrics", &self.metrics)
42            .field("memory_usage", &self.memory_usage)
43            .field(
44                "event_handlers",
45                &format!("{} handlers", self.event_handlers.len()),
46            )
47            .finish()
48    }
49}
50
51/// Base optimizer state
52#[derive(Debug, Clone)]
53pub struct BaseOptimizerState<A: Float + std::fmt::Debug> {
54    /// Step count
55    pub step_count: usize,
56    /// Parameter count
57    pub param_count: usize,
58    /// Learning rate history
59    pub lr_history: Vec<A>,
60    /// Gradient norms history
61    pub grad_norm_history: Vec<A>,
62    /// Parameter change norms history
63    pub param_change_history: Vec<A>,
64    /// Custom state data
65    pub custom_state: HashMap<String, StateValue>,
66}
67
68/// Plugin development utilities
69pub struct PluginSDK;
70
71/// Plugin testing framework
72pub struct PluginTester<A: Float> {
73    /// Test configuration
74    config: TestConfig,
75    /// Test suite
76    test_suite: TestSuite<A>,
77    /// Benchmark suite
78    benchmark_suite: BenchmarkSuite<A>,
79    /// Validation framework
80    validator: PluginValidator<A>,
81}
82
/// Knobs controlling a plugin test run.
#[derive(Clone, Debug)]
pub struct TestConfig {
    /// How many times each test is iterated
    pub iterations: usize,
    /// Absolute tolerance used by numerical comparisons
    pub tolerance: f64,
    /// Seed for reproducible randomness
    pub random_seed: u64,
    /// Whether performance tests are run
    pub enable_performance_tests: bool,
    /// Whether memory tests are run
    pub enable_memory_tests: bool,
    /// Whether convergence tests are run
    pub enable_convergence_tests: bool,
}
99
100/// Test suite for plugin validation
101#[derive(Debug)]
102pub struct TestSuite<A: Float> {
103    /// Functionality tests
104    pub functionality_tests: Vec<Box<dyn PluginTest<A>>>,
105    /// Performance tests
106    pub performance_tests: Vec<Box<dyn PerformanceTest<A>>>,
107    /// Convergence tests
108    pub convergence_tests: Vec<Box<dyn ConvergenceTest<A>>>,
109    /// Memory tests
110    pub memory_tests: Vec<Box<dyn MemoryTest<A>>>,
111}
112
113/// Individual plugin test trait
114pub trait PluginTest<A: Float>: Debug {
115    /// Run the test
116    fn run_test(&self, plugin: &mut dyn OptimizerPlugin<A>) -> TestResult;
117
118    /// Get test name
119    fn name(&self) -> &str;
120
121    /// Get test description
122    fn description(&self) -> &str;
123}
124
125/// Performance test trait
126pub trait PerformanceTest<A: Float>: Debug {
127    /// Run performance test
128    fn run_performance_test(&self, plugin: &mut dyn OptimizerPlugin<A>) -> PerformanceTestResult;
129
130    /// Get test name
131    fn name(&self) -> &str;
132
133    /// Get performance baseline
134    #[cfg(feature = "cross-platform-testing")]
135    fn baseline(&self) -> PerformanceBaseline;
136}
137
138/// Convergence test trait
139pub trait ConvergenceTest<A: Float>: Debug {
140    /// Run convergence test
141    fn run_convergence_test(&self, plugin: &mut dyn OptimizerPlugin<A>)
142        -> ConvergenceTestResult<A>;
143
144    /// Get test name
145    fn name(&self) -> &str;
146
147    /// Get convergence criteria
148    fn convergence_criteria(&self) -> ConvergenceCriteria<A>;
149}
150
151/// Memory test trait
152pub trait MemoryTest<A: Float>: Debug {
153    /// Run memory test
154    fn run_memory_test(&self, plugin: &mut dyn OptimizerPlugin<A>) -> MemoryTestResult;
155
156    /// Get test name
157    fn name(&self) -> &str;
158
159    /// Get memory constraints
160    fn memory_constraints(&self) -> MemoryConstraints;
161}
162
163/// Test result
164#[derive(Debug, Clone)]
165pub struct TestResult {
166    /// Test passed
167    pub passed: bool,
168    /// Test message
169    pub message: String,
170    /// Execution time
171    pub execution_time: std::time::Duration,
172    /// Additional data
173    pub data: HashMap<String, serde_json::Value>,
174}
175
176/// Performance test result
177#[derive(Debug, Clone)]
178pub struct PerformanceTestResult {
179    /// Performance metrics
180    pub metrics: PerformanceMetrics,
181    /// Comparison with baseline
182    pub baseline_comparison: BaselineComparison,
183    /// Performance score (0.0 to 1.0)
184    pub performance_score: f64,
185}
186
187/// Convergence test result
188#[derive(Debug, Clone)]
189pub struct ConvergenceTestResult<A: Float> {
190    /// Converged successfully
191    pub converged: bool,
192    /// Number of iterations to convergence
193    pub iterations_to_convergence: Option<usize>,
194    /// Final objective value
195    pub final_objective: A,
196    /// Convergence rate
197    pub convergence_rate: f64,
198    /// Convergence metrics
199    pub metrics: ConvergenceMetrics,
200}
201
202/// Memory test result
203#[derive(Debug, Clone)]
204pub struct MemoryTestResult {
205    /// Memory usage metrics
206    pub memory_metrics: MemoryUsage,
207    /// Memory leak detected
208    pub memory_leak_detected: bool,
209    /// Memory efficiency score
210    pub efficiency_score: f64,
211}
212
/// Comparison of a measurement against a performance baseline.
#[derive(Clone, Debug)]
pub struct BaselineComparison {
    /// Performance relative to the baseline (baseline = 1.0)
    pub relative_performance: f64,
    /// Absolute difference from the baseline
    pub absolute_difference: f64,
    /// Improvement over the baseline, in percent
    pub improvement_percent: f64,
}
223
224/// Convergence criteria
225#[derive(Debug, Clone)]
226pub struct ConvergenceCriteria<A: Float> {
227    /// Maximum iterations
228    pub max_iterations: usize,
229    /// Gradient norm tolerance
230    pub gradient_tolerance: A,
231    /// Function value tolerance
232    pub function_tolerance: A,
233    /// Parameter change tolerance
234    pub parameter_tolerance: A,
235}
236
/// Memory limits a plugin is tested against.
#[derive(Clone, Debug)]
pub struct MemoryConstraints {
    /// Maximum allowed memory usage, in bytes
    pub max_memory_usage: usize,
    /// Maximum allowed number of allocations
    pub max_allocations: usize,
    /// Tolerated leaked bytes
    pub leak_tolerance: usize,
}
247
248/// Plugin validator for comprehensive validation
249#[derive(Debug)]
250pub struct PluginValidator<A: Float> {
251    /// Validation rules
252    rules: Vec<Box<dyn ValidationRule<A>>>,
253    /// Compatibility checker
254    compatibility_checker: CompatibilityChecker,
255}
256
257/// Validation rule trait
258pub trait ValidationRule<A: Float>: Debug {
259    /// Validate plugin
260    fn validate(&self, plugin: &dyn OptimizerPlugin<A>) -> ValidationResult;
261
262    /// Get rule name
263    fn name(&self) -> &str;
264
265    /// Get rule severity
266    fn severity(&self) -> ValidationSeverity;
267}
268
/// Result of applying a single [`ValidationRule`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Validation passed
    pub passed: bool,
    /// Validation message
    pub message: String,
    /// Severity level
    pub severity: ValidationSeverity,
    /// Suggestions for improvement
    pub suggestions: Vec<String>,
}

/// Validation severity levels, from least to most severe.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `Info < Warning < Error < Critical`. This lets callers filter results with
/// comparisons such as `severity >= ValidationSeverity::Error` or pick the worst
/// severity with `max()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ValidationSeverity {
    /// Informational note
    Info,
    /// Potential problem
    Warning,
    /// Definite problem
    Error,
    /// Problem severe enough to reject the plugin
    Critical,
}
290
/// Describes the platforms, Rust toolchains and dependency versions a plugin
/// is checked for compatibility against.
#[derive(Debug)]
pub struct CompatibilityChecker {
    /// Platforms the plugin must support
    target_platforms: Vec<String>,
    /// Required Rust versions
    rust_versions: Vec<String>,
    /// Dependency compatibility, keyed by dependency name
    dependency_compatibility: HashMap<String, String>,
}
301
302/// Benchmark suite for performance evaluation
303#[derive(Debug)]
304pub struct BenchmarkSuite<A: Float> {
305    /// Standard benchmarks
306    standard_benchmarks: Vec<Box<dyn Benchmark<A>>>,
307    /// Custom benchmarks
308    custom_benchmarks: Vec<Box<dyn Benchmark<A>>>,
309    /// Benchmark configuration
310    config: BenchmarkConfig,
311}
312
313/// Benchmark trait
314pub trait Benchmark<A: Float>: Debug {
315    /// Run benchmark
316    fn run_benchmark(&self, plugin: &mut dyn OptimizerPlugin<A>) -> BenchmarkResult<A>;
317
318    /// Get benchmark name
319    fn name(&self) -> &str;
320
321    /// Get benchmark description
322    fn description(&self) -> &str;
323
324    /// Get benchmark category
325    fn category(&self) -> BenchmarkCategory;
326}
327
/// Benchmark categories.
///
/// `Copy`, `PartialEq`, `Eq` and `Hash` are derived so categories can be compared
/// directly and used as keys when grouping benchmark results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BenchmarkCategory {
    /// Speed benchmarks
    Speed,
    /// Memory benchmarks
    Memory,
    /// Accuracy benchmarks
    Accuracy,
    /// Scalability benchmarks
    Scalability,
    /// Robustness benchmarks
    Robustness,
}
342
343/// Benchmark result
344#[derive(Debug, Clone)]
345pub struct BenchmarkResult<A: Float> {
346    /// Benchmark name
347    pub name: String,
348    /// Score (higher is better)
349    pub score: f64,
350    /// Metrics
351    pub metrics: HashMap<String, f64>,
352    /// Execution time
353    pub execution_time: std::time::Duration,
354    /// Memory usage
355    pub memory_usage: usize,
356    /// Additional data
357    pub data: HashMap<String, A>,
358}
359
/// Settings controlling how benchmarks are executed.
#[derive(Clone, Debug)]
pub struct BenchmarkConfig {
    /// Number of measured runs
    pub runs: usize,
    /// Iterations executed before measuring
    pub warmup_iterations: usize,
    /// Problem sizes to benchmark
    pub problem_sizes: Vec<usize>,
    /// Seeds used for the runs
    pub random_seeds: Vec<u64>,
}
372
373/// Plugin development helper macros and utilities
374impl PluginSDK {
375    /// Create a plugin template with common functionality
376    pub fn create_plugin_template(name: &str) -> PluginTemplate {
377        PluginTemplate::new(name)
378    }
379
380    /// Validate plugin configuration schema
381    pub fn validate_config_schema(schema: &ConfigSchema) -> Result<()> {
382        for (field_name, field_schema) in &schema.fields {
383            if field_name.is_empty() {
384                return Err(OptimError::InvalidConfig(
385                    "Field name cannot be empty".to_string(),
386                ));
387            }
388
389            if field_schema.description.is_empty() {
390                return Err(OptimError::InvalidConfig(format!(
391                    "Field '{}' must have a description",
392                    field_name
393                )));
394            }
395        }
396        Ok(())
397    }
398
399    /// Generate plugin manifest template
400    pub fn generate_plugin_manifest(info: &PluginInfo) -> String {
401        format!(
402            r#"[plugin]
403name = "{}"
404version = "{}"
405description = "{}"
406author = "{}"
407license = "{}"
408entry_point = "plugin_main"
409
410[build]
411rust_version = "1.70.0"
412target = "*"
413profile = "release"
414
415[runtime]
416min_rust_version = "1.70.0"
417"#,
418            info.name, info.version, info.description, info.author, info.license
419        )
420    }
421
422    /// Create default test configuration
423    pub fn default_test_config() -> TestConfig {
424        TestConfig {
425            iterations: 100,
426            tolerance: 1e-6,
427            random_seed: 42,
428            enable_performance_tests: true,
429            enable_memory_tests: true,
430            enable_convergence_tests: true,
431        }
432    }
433
434    /// Create performance baseline from existing optimizer
435    #[cfg(feature = "cross-platform-testing")]
436    pub fn create_performance_baseline<A>(
437        optimizer: &mut dyn OptimizerPlugin<A>,
438        test_data: &[(Array1<A>, Array1<A>)],
439    ) -> PerformanceBaseline
440    where
441        A: Float + Debug + Send + Sync + 'static,
442    {
443        let start_time = std::time::Instant::now();
444        let mut total_memory = 0;
445
446        for (params, gradients) in test_data {
447            let _result = optimizer.step(params, gradients);
448            total_memory += optimizer.memory_usage().current_usage;
449        }
450
451        let execution_time = start_time.elapsed();
452        let _avg_memory = total_memory / test_data.len();
453
454        PerformanceBaseline {
455            target: PlatformTarget::CPU,
456            throughput_ops_per_sec: test_data.len() as f64 / execution_time.as_secs_f64(),
457            latency_ms: execution_time.as_secs_f64() * 1000.0 / test_data.len() as f64,
458            memory_usage_mb: total_memory as f64 / (1024.0 * 1024.0),
459            energy_consumption_joules: None,
460            accuracy_metrics: HashMap::new(),
461        }
462    }
463}
464
/// Plugin template for rapid development.
#[derive(Debug, Clone)]
pub struct PluginTemplate {
    /// Template name
    name: String,
    /// Template structure
    structure: TemplateStructure,
}

/// Template structure definition: the files a template generates, grouped by role.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TemplateStructure {
    /// Source files
    pub source_files: Vec<TemplateFile>,
    /// Configuration files
    pub config_files: Vec<TemplateFile>,
    /// Test files
    pub test_files: Vec<TemplateFile>,
    /// Documentation files
    pub doc_files: Vec<TemplateFile>,
}

/// A single generated file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TemplateFile {
    /// File path, relative to the output directory
    pub path: String,
    /// File content
    pub content: String,
    /// File type
    pub file_type: TemplateFileType,
}

/// Template file types.
///
/// `Copy`, `PartialEq`, `Eq` and `Hash` are derived so callers can filter or group
/// files by type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TemplateFileType {
    /// Rust source file
    RustSource,
    /// TOML configuration
    TomlConfig,
    /// Markdown documentation
    Markdown,
    /// Test file
    Test,
}
510
511impl PluginTemplate {
512    /// Create a new plugin template
513    pub fn new(name: &str) -> Self {
514        let structure = Self::create_default_structure(name);
515        Self {
516            name: name.to_string(),
517            structure,
518        }
519    }
520
521    /// Generate template files to directory
522    pub fn generate_to_directory(&self, outputdir: &std::path::Path) -> Result<()> {
523        std::fs::create_dir_all(outputdir)?;
524
525        for file in &self.structure.source_files {
526            let file_path = outputdir.join(&file.path);
527            if let Some(parent) = file_path.parent() {
528                std::fs::create_dir_all(parent)?;
529            }
530            std::fs::write(&file_path, &file.content)?;
531        }
532
533        for file in &self.structure.config_files {
534            let file_path = outputdir.join(&file.path);
535            std::fs::write(&file_path, &file.content)?;
536        }
537
538        for file in &self.structure.test_files {
539            let file_path = outputdir.join(&file.path);
540            if let Some(parent) = file_path.parent() {
541                std::fs::create_dir_all(parent)?;
542            }
543            std::fs::write(&file_path, &file.content)?;
544        }
545
546        Ok(())
547    }
548
549    fn create_default_structure(name: &str) -> TemplateStructure {
550        let lib_rs_content = format!(
551            r#"//! {} optimizer plugin
552//
553// This is an auto-generated plugin template.
554
555use optirs_core::plugin::*;
556use scirs2_core::ndarray::Array1;
557use scirs2_core::numeric::Float;
558
559#[derive(Debug)]
560pub struct {}Optimizer<A: Float> {{
561    learning_rate: A,
562    // Add your optimizer state here
563}}
564
565impl<A: Float + Send + Sync> {}Optimizer<A> {{
566    pub fn new(_learningrate: A) -> Self {{
567        Self {{
568            learning_rate,
569        }}
570    }}
571}}
572
573impl<A: Float + std::fmt::Debug + Send + Sync + 'static> OptimizerPlugin<A> for {}Optimizer<A> {{
574    fn step(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>> {{
575        // Implement your optimization step here
576        Ok(params - &(gradients * self.learning_rate))
577    }}
578    
579    fn name(&self) -> &str {{
580        "{}"
581    }}
582    
583    fn version(&self) -> &str {{
584        "0.1.0"
585    }}
586    
587    fn plugin_info(&self) -> PluginInfo {{
588        create_plugin_info("{}", "0.1.0", "Plugin Developer")
589    }}
590    
591    fn capabilities(&self) -> PluginCapabilities {{
592        create_basic_capabilities()
593    }}
594    
595    fn initialize(&mut self, paramshape: &[usize]) -> Result<()> {{
596        Ok(())
597    }}
598    
599    fn reset(&mut self) -> Result<()> {{
600        Ok(())
601    }}
602    
603    fn get_config(&self) -> OptimizerConfig {{
604        OptimizerConfig::default()
605    }}
606    
607    fn set_config(&mut self, config: OptimizerConfig) -> Result<()> {{
608        Ok(())
609    }}
610    
611    fn get_state(&self) -> Result<OptimizerState> {{
612        Ok(OptimizerState::default())
613    }}
614    
615    fn set_state(&mut self, state: OptimizerState) -> Result<()> {{
616        Ok(())
617    }}
618    
619    fn clone_plugin(&self) -> Box<dyn OptimizerPlugin<A>> {{
620        Box::new(Self::new(self.learning_rate))
621    }}
622}}
623
624// Plugin factory implementation
625#[derive(Debug)]
626pub struct {}Factory;
627
628impl<A: Float + std::fmt::Debug + Send + Sync + 'static> OptimizerPluginFactory<A> for {}Factory {{
629    fn create_optimizer(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<A>>> {{
630        let learning_rate = A::from(config.learning_rate).unwrap();
631        Ok(Box::new({}Optimizer::new(learning_rate)))
632    }}
633    
634    fn factory_info(&self) -> PluginInfo {{
635        create_plugin_info("{}", "0.1.0", "Plugin Developer")
636    }}
637    
638    fn validate_config(&self, config: &OptimizerConfig) -> Result<()> {{
639        if config.learning_rate <= 0.0 {{
640            return Err(OptimError::InvalidConfig(
641                "Learning rate must be positive".to_string(),
642            ));
643        }}
644        Ok(())
645    }}
646    
647    fn default_config(&self) -> OptimizerConfig {{
648        OptimizerConfig {{
649            learning_rate: 0.001,
650            ..Default::default()
651        }}
652    }}
653    
654    fn config_schema(&self) -> ConfigSchema {{
655        let mut schema = ConfigSchema {{
656            fields: std::collections::HashMap::new(),
657            required_fields: vec!["learning_rate".to_string()],
658            version: "1.0".to_string(),
659        }};
660        
661        schema.fields.insert(
662            "learning_rate".to_string(),
663            FieldSchema {{
664                field_type: FieldType::Float {{ min: Some(0.0), max: None }},
665                description: "Learning rate for optimization".to_string(),
666                default_value: Some(ConfigValue::Float(0.001)),
667                constraints: vec![ValidationConstraint::Positive],
668                required: true,
669            }},
670        );
671        
672        schema
673    }}
674}}
675"#,
676            name, name, name, name, name, name, name, name, name, name
677        );
678
679        let plugin_toml_content = format!(
680            r#"[plugin]
681name = "{}"
682version = "0.1.0"
683description = "Custom optimizer plugin"
684author = "Plugin Developer"
685license = "MIT"
686entry_point = "plugin_main"
687
688[build]
689rust_version = "1.70.0"
690target = "*"
691profile = "release"
692
693[runtime]
694min_rust_version = "1.70.0"
695"#,
696            name
697        );
698
699        let test_content = format!(
700            r#"//! Tests for {} optimizer plugin
701
702use super::*;
703use scirs2_core::ndarray::Array1;
704
705#[test]
706#[allow(dead_code)]
707fn test_{}_basic_functionality() {{
708    let mut optimizer = {}Optimizer::new(0.01);
709    let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
710    let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
711    
712    let result = optimizer.step(&params, &gradients).unwrap();
713    
714    // Verify the result
715    assert!((result[0] - 0.999).abs() < 1e-6);
716    assert!((result[1] - 1.998).abs() < 1e-6);
717    assert!((result[2] - 2.997).abs() < 1e-6);
718}}
719
720#[test]
721#[allow(dead_code)]
722fn test_{}_convergence() {{
723    let mut optimizer = {}Optimizer::new(0.1);
724    let mut params = Array1::from_vec(vec![1.0, 1.0]);
725    
726    // Optimize towards zero
727    for _ in 0..100 {{
728        let gradients = &params * 2.0; // Gradient of x^2
729        params = optimizer.step(&params, &gradients).unwrap();
730    }}
731    
732    // Should converge close to zero
733    assert!(params.iter().all(|&x| x.abs() < 0.1));
734}}
735"#,
736            name,
737            name.to_lowercase(),
738            name,
739            name.to_lowercase(),
740            name
741        );
742
743        TemplateStructure {
744            source_files: vec![TemplateFile {
745                path: "src/lib.rs".to_string(),
746                content: lib_rs_content,
747                file_type: TemplateFileType::RustSource,
748            }],
749            config_files: vec![TemplateFile {
750                path: "plugin.toml".to_string(),
751                content: plugin_toml_content,
752                file_type: TemplateFileType::TomlConfig,
753            }],
754            test_files: vec![TemplateFile {
755                path: "tests/integration_tests.rs".to_string(),
756                content: test_content,
757                file_type: TemplateFileType::Test,
758            }],
759            doc_files: vec![],
760        }
761    }
762}
763
764// Implementation for base optimizer plugin
765
766impl<A: Float + Debug + Send + Sync + 'static> BaseOptimizerPlugin<A> {
767    /// Create a new base optimizer plugin
768    pub fn new(info: PluginInfo, capabilities: PluginCapabilities) -> Self {
769        Self {
770            info,
771            capabilities,
772            config: OptimizerConfig::default(),
773            state: BaseOptimizerState::new(),
774            metrics: PerformanceMetrics::default(),
775            memory_usage: MemoryUsage::default(),
776            event_handlers: Vec::new(),
777        }
778    }
779
780    /// Add event handler
781    pub fn add_event_handler(&mut self, handler: Box<dyn PluginEventHandler>) {
782        self.event_handlers.push(handler);
783    }
784
785    /// Update performance metrics
786    pub fn update_metrics(&mut self, steptime: std::time::Duration) {
787        self.metrics.total_steps += 1;
788        self.metrics.avg_step_time = (self.metrics.avg_step_time
789            * (self.metrics.total_steps - 1) as f64
790            + steptime.as_secs_f64())
791            / self.metrics.total_steps as f64;
792        self.metrics.throughput = 1.0 / self.metrics.avg_step_time;
793    }
794}
795
796impl<A: Float + std::fmt::Debug + Send + Sync> BaseOptimizerState<A> {
797    fn new() -> Self {
798        Self {
799            step_count: 0,
800            param_count: 0,
801            lr_history: Vec::new(),
802            grad_norm_history: Vec::new(),
803            param_change_history: Vec::new(),
804            custom_state: HashMap::new(),
805        }
806    }
807}
808
809// Default implementations
810
811impl Default for TestConfig {
812    fn default() -> Self {
813        Self {
814            iterations: 100,
815            tolerance: 1e-6,
816            random_seed: 42,
817            enable_performance_tests: true,
818            enable_memory_tests: true,
819            enable_convergence_tests: true,
820        }
821    }
822}
823
824impl Default for BenchmarkConfig {
825    fn default() -> Self {
826        Self {
827            runs: 10,
828            warmup_iterations: 5,
829            problem_sizes: vec![10, 100, 1000],
830            random_seeds: vec![42, 123, 456],
831        }
832    }
833}
834
/// Macro for creating a simple optimizer plugin.
///
/// `create_optimizer_plugin!(Name, step_fn)` defines a struct `Name<A>` that holds
/// an `OptimizerConfig` and an `OptimizerState`, plus an `OptimizerPlugin<A>`
/// impl whose `step` is delegated to `step_fn(self, params, gradients)`. Here
/// `self` is `&mut Name<A>` and the function must return `Result<Array1<A>>`.
///
/// The expansion names `Float`, `OptimizerConfig`, `OptimizerState`,
/// `OptimizerPlugin`, `PluginInfo`, `PluginCapabilities`, `Array1`, `Result`,
/// `create_plugin_info` and `create_basic_capabilities` unqualified, so they
/// must be in scope at the call site.
#[macro_export]
macro_rules! create_optimizer_plugin {
    ($name:ident, $step_fn:expr) => {
        #[derive(Debug)]
        pub struct $name<A: Float> {
            config: OptimizerConfig,
            state: OptimizerState,
            // Field name must match the initialisers below (`_phantom`).
            _phantom: std::marker::PhantomData<A>,
        }

        impl<A: Float + Send + Sync> $name<A> {
            pub fn new() -> Self {
                Self {
                    config: OptimizerConfig::default(),
                    state: OptimizerState::default(),
                    _phantom: std::marker::PhantomData,
                }
            }
        }

        impl<A: Float + std::fmt::Debug + Send + Sync + 'static> OptimizerPlugin<A> for $name<A> {
            fn step(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>> {
                $step_fn(self, params, gradients)
            }

            fn name(&self) -> &str {
                stringify!($name)
            }

            fn version(&self) -> &str {
                "0.1.0"
            }

            fn plugin_info(&self) -> PluginInfo {
                create_plugin_info(stringify!($name), "0.1.0", "Auto-generated")
            }

            fn capabilities(&self) -> PluginCapabilities {
                create_basic_capabilities()
            }

            fn initialize(&mut self, _param_shape: &[usize]) -> Result<()> {
                Ok(())
            }

            fn reset(&mut self) -> Result<()> {
                self.state = OptimizerState::default();
                Ok(())
            }

            fn get_config(&self) -> OptimizerConfig {
                self.config.clone()
            }

            fn set_config(&mut self, config: OptimizerConfig) -> Result<()> {
                self.config = config;
                Ok(())
            }

            fn get_state(&self) -> Result<OptimizerState> {
                Ok(self.state.clone())
            }

            fn set_state(&mut self, state: OptimizerState) -> Result<()> {
                self.state = state;
                Ok(())
            }

            fn clone_plugin(&self) -> Box<dyn OptimizerPlugin<A>> {
                // Carry over the current configuration and state so the clone
                // behaves like this instance rather than a freshly built plugin.
                Box::new(Self {
                    config: self.config.clone(),
                    state: self.state.clone(),
                    _phantom: std::marker::PhantomData,
                })
            }
        }
    };
}
910
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unconstrained float field schema with the given description.
    fn float_field(description: &str, required: bool) -> FieldSchema {
        FieldSchema {
            field_type: FieldType::Float {
                min: None,
                max: None,
            },
            description: description.to_string(),
            default_value: None,
            constraints: Vec::new(),
            required,
        }
    }

    /// Build an empty schema with the given required field names.
    fn empty_schema(required_fields: Vec<String>) -> ConfigSchema {
        ConfigSchema {
            fields: HashMap::new(),
            required_fields,
            version: "1.0".to_string(),
        }
    }

    #[test]
    fn test_plugin_template_creation() {
        let template = PluginTemplate::new("TestOptimizer");
        assert_eq!(template.name, "TestOptimizer");
        assert!(!template.structure.source_files.is_empty());
        // The generated source names its optimizer type after the template.
        assert!(template.structure.source_files[0]
            .content
            .contains("pub struct TestOptimizerOptimizer"));
    }

    #[test]
    fn test_test_config_default() {
        let config = TestConfig::default();
        assert_eq!(config.iterations, 100);
        assert!(config.enable_performance_tests);
    }

    #[test]
    fn test_sdk_default_test_config_matches_default() {
        let sdk = PluginSDK::default_test_config();
        let default = TestConfig::default();
        assert_eq!(sdk.iterations, default.iterations);
        assert_eq!(sdk.tolerance, default.tolerance);
        assert_eq!(sdk.random_seed, default.random_seed);
        assert_eq!(sdk.enable_performance_tests, default.enable_performance_tests);
        assert_eq!(sdk.enable_memory_tests, default.enable_memory_tests);
        assert_eq!(sdk.enable_convergence_tests, default.enable_convergence_tests);
    }

    #[test]
    fn test_sdk_config_validation() {
        let mut schema = empty_schema(vec!["test_field".to_string()]);
        schema
            .fields
            .insert("test_field".to_string(), float_field("Test field", true));

        assert!(PluginSDK::validate_config_schema(&schema).is_ok());

        // Test with empty field name
        schema.fields.insert("".to_string(), float_field("Test", false));

        assert!(PluginSDK::validate_config_schema(&schema).is_err());
    }

    #[test]
    fn test_sdk_config_validation_rejects_missing_description() {
        let mut schema = empty_schema(Vec::new());
        schema
            .fields
            .insert("learning_rate".to_string(), float_field("", true));

        assert!(PluginSDK::validate_config_schema(&schema).is_err());
    }

    #[test]
    fn test_sdk_config_validation_accepts_empty_schema() {
        let schema = empty_schema(Vec::new());
        assert!(PluginSDK::validate_config_schema(&schema).is_ok());
    }
}