optirs_core/plugin/
sdk.rs

// Plugin SDK utilities and helpers for optimizer development
//
// This module provides an SDK for developing custom optimizer plugins,
// including base types, utilities, testing frameworks, and development tools.
5
6#[allow(dead_code)]
7use super::core::*;
8use crate::benchmarking::cross_platform_tester::{PerformanceBaseline, PlatformTarget};
9use crate::error::{OptimError, Result};
10use scirs2_core::ndarray::Array1;
11use scirs2_core::numeric::Float;
12use std::collections::HashMap;
13use std::fmt::Debug;
14
15/// Base optimizer plugin implementation with common functionality
16pub struct BaseOptimizerPlugin<A: Float + std::fmt::Debug> {
17    /// Plugin information
18    info: PluginInfo,
19    /// Plugin capabilities
20    capabilities: PluginCapabilities,
21    /// Optimizer configuration
22    config: OptimizerConfig,
23    /// Internal state
24    state: BaseOptimizerState<A>,
25    /// Performance metrics
26    metrics: PerformanceMetrics,
27    /// Memory usage tracking
28    memory_usage: MemoryUsage,
29    /// Event handlers
30    event_handlers: Vec<Box<dyn PluginEventHandler>>,
31}
32
33impl<A: Float + std::fmt::Debug + Send + Sync> std::fmt::Debug for BaseOptimizerPlugin<A> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("BaseOptimizerPlugin")
36            .field("info", &self.info)
37            .field("capabilities", &self.capabilities)
38            .field("config", &self.config)
39            .field("state", &self.state)
40            .field("metrics", &self.metrics)
41            .field("memory_usage", &self.memory_usage)
42            .field(
43                "event_handlers",
44                &format!("{} handlers", self.event_handlers.len()),
45            )
46            .finish()
47    }
48}
49
50/// Base optimizer state
51#[derive(Debug, Clone)]
52pub struct BaseOptimizerState<A: Float + std::fmt::Debug> {
53    /// Step count
54    pub step_count: usize,
55    /// Parameter count
56    pub param_count: usize,
57    /// Learning rate history
58    pub lr_history: Vec<A>,
59    /// Gradient norms history
60    pub grad_norm_history: Vec<A>,
61    /// Parameter change norms history
62    pub param_change_history: Vec<A>,
63    /// Custom state data
64    pub custom_state: HashMap<String, StateValue>,
65}
66
/// Entry point for the plugin development utilities.
///
/// A unit type: it carries no data and only serves as the home of the SDK's
/// associated helper functions.
pub struct PluginSDK;
69
70/// Plugin testing framework
71pub struct PluginTester<A: Float> {
72    /// Test configuration
73    config: TestConfig,
74    /// Test suite
75    test_suite: TestSuite<A>,
76    /// Benchmark suite
77    benchmark_suite: BenchmarkSuite<A>,
78    /// Validation framework
79    validator: PluginValidator<A>,
80}
81
/// Settings that control a plugin test run.
#[derive(Clone, Debug)]
pub struct TestConfig {
    /// How many test iterations to run
    pub iterations: usize,
    /// Tolerance used by numerical tests
    pub tolerance: f64,
    /// Seed for the random generator, for reproducibility
    pub random_seed: u64,
    /// Whether performance tests are enabled
    pub enable_performance_tests: bool,
    /// Whether memory tests are enabled
    pub enable_memory_tests: bool,
    /// Whether convergence tests are enabled
    pub enable_convergence_tests: bool,
}
98
99/// Test suite for plugin validation
100#[derive(Debug)]
101pub struct TestSuite<A: Float> {
102    /// Functionality tests
103    pub functionality_tests: Vec<Box<dyn PluginTest<A>>>,
104    /// Performance tests
105    pub performance_tests: Vec<Box<dyn PerformanceTest<A>>>,
106    /// Convergence tests
107    pub convergence_tests: Vec<Box<dyn ConvergenceTest<A>>>,
108    /// Memory tests
109    pub memory_tests: Vec<Box<dyn MemoryTest<A>>>,
110}
111
112/// Individual plugin test trait
113pub trait PluginTest<A: Float>: Debug {
114    /// Run the test
115    fn run_test(&self, plugin: &mut dyn OptimizerPlugin<A>) -> TestResult;
116
117    /// Get test name
118    fn name(&self) -> &str;
119
120    /// Get test description
121    fn description(&self) -> &str;
122}
123
124/// Performance test trait
125pub trait PerformanceTest<A: Float>: Debug {
126    /// Run performance test
127    fn run_performance_test(&self, plugin: &mut dyn OptimizerPlugin<A>) -> PerformanceTestResult;
128
129    /// Get test name
130    fn name(&self) -> &str;
131
132    /// Get performance baseline
133    fn baseline(&self) -> PerformanceBaseline;
134}
135
136/// Convergence test trait
137pub trait ConvergenceTest<A: Float>: Debug {
138    /// Run convergence test
139    fn run_convergence_test(&self, plugin: &mut dyn OptimizerPlugin<A>)
140        -> ConvergenceTestResult<A>;
141
142    /// Get test name
143    fn name(&self) -> &str;
144
145    /// Get convergence criteria
146    fn convergence_criteria(&self) -> ConvergenceCriteria<A>;
147}
148
149/// Memory test trait
150pub trait MemoryTest<A: Float>: Debug {
151    /// Run memory test
152    fn run_memory_test(&self, plugin: &mut dyn OptimizerPlugin<A>) -> MemoryTestResult;
153
154    /// Get test name
155    fn name(&self) -> &str;
156
157    /// Get memory constraints
158    fn memory_constraints(&self) -> MemoryConstraints;
159}
160
161/// Test result
162#[derive(Debug, Clone)]
163pub struct TestResult {
164    /// Test passed
165    pub passed: bool,
166    /// Test message
167    pub message: String,
168    /// Execution time
169    pub execution_time: std::time::Duration,
170    /// Additional data
171    pub data: HashMap<String, serde_json::Value>,
172}
173
174/// Performance test result
175#[derive(Debug, Clone)]
176pub struct PerformanceTestResult {
177    /// Performance metrics
178    pub metrics: PerformanceMetrics,
179    /// Comparison with baseline
180    pub baseline_comparison: BaselineComparison,
181    /// Performance score (0.0 to 1.0)
182    pub performance_score: f64,
183}
184
185/// Convergence test result
186#[derive(Debug, Clone)]
187pub struct ConvergenceTestResult<A: Float> {
188    /// Converged successfully
189    pub converged: bool,
190    /// Number of iterations to convergence
191    pub iterations_to_convergence: Option<usize>,
192    /// Final objective value
193    pub final_objective: A,
194    /// Convergence rate
195    pub convergence_rate: f64,
196    /// Convergence metrics
197    pub metrics: ConvergenceMetrics,
198}
199
200/// Memory test result
201#[derive(Debug, Clone)]
202pub struct MemoryTestResult {
203    /// Memory usage metrics
204    pub memory_metrics: MemoryUsage,
205    /// Memory leak detected
206    pub memory_leak_detected: bool,
207    /// Memory efficiency score
208    pub efficiency_score: f64,
209}
210
/// Comparison of a measurement against a baseline.
#[derive(Clone, Debug)]
pub struct BaselineComparison {
    /// Performance relative to the baseline (baseline = 1.0)
    pub relative_performance: f64,
    /// Absolute performance difference
    pub absolute_difference: f64,
    /// Performance improvement as a percentage
    pub improvement_percent: f64,
}
221
222/// Convergence criteria
223#[derive(Debug, Clone)]
224pub struct ConvergenceCriteria<A: Float> {
225    /// Maximum iterations
226    pub max_iterations: usize,
227    /// Gradient norm tolerance
228    pub gradient_tolerance: A,
229    /// Function value tolerance
230    pub function_tolerance: A,
231    /// Parameter change tolerance
232    pub parameter_tolerance: A,
233}
234
/// Memory limits applied during memory tests.
#[derive(Clone, Debug)]
pub struct MemoryConstraints {
    /// Upper bound on memory usage (bytes)
    pub max_memory_usage: usize,
    /// Upper bound on the number of allocations
    pub max_allocations: usize,
    /// Tolerated amount of leaked memory (bytes)
    pub leak_tolerance: usize,
}
245
246/// Plugin validator for comprehensive validation
247#[derive(Debug)]
248pub struct PluginValidator<A: Float> {
249    /// Validation rules
250    rules: Vec<Box<dyn ValidationRule<A>>>,
251    /// Compatibility checker
252    compatibility_checker: CompatibilityChecker,
253}
254
255/// Validation rule trait
256pub trait ValidationRule<A: Float>: Debug {
257    /// Validate plugin
258    fn validate(&self, plugin: &dyn OptimizerPlugin<A>) -> ValidationResult;
259
260    /// Get rule name
261    fn name(&self) -> &str;
262
263    /// Get rule severity
264    fn severity(&self) -> ValidationSeverity;
265}
266
267/// Validation result
268#[derive(Debug, Clone)]
269pub struct ValidationResult {
270    /// Validation passed
271    pub passed: bool,
272    /// Validation message
273    pub message: String,
274    /// Severity level
275    pub severity: ValidationSeverity,
276    /// Suggestions for improvement
277    pub suggestions: Vec<String>,
278}
279
/// Validation severity levels.
///
/// Variants are declared from least to most severe, so the derived ordering
/// gives `Info < Warning < Error < Critical` and severities can be compared
/// against a threshold directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ValidationSeverity {
    /// Informational finding
    Info,
    /// Finding that deserves attention
    Warning,
    /// Finding that represents an error
    Error,
    /// Most severe finding
    Critical,
}
288
/// Holds the compatibility requirements a plugin is checked against.
#[derive(Debug)]
pub struct CompatibilityChecker {
    /// Platforms being targeted
    target_platforms: Vec<String>,
    /// Required Rust versions
    rust_versions: Vec<String>,
    /// Dependency compatibility entries keyed by name
    dependency_compatibility: HashMap<String, String>,
}
299
300/// Benchmark suite for performance evaluation
301#[derive(Debug)]
302pub struct BenchmarkSuite<A: Float> {
303    /// Standard benchmarks
304    standard_benchmarks: Vec<Box<dyn Benchmark<A>>>,
305    /// Custom benchmarks
306    custom_benchmarks: Vec<Box<dyn Benchmark<A>>>,
307    /// Benchmark configuration
308    config: BenchmarkConfig,
309}
310
311/// Benchmark trait
312pub trait Benchmark<A: Float>: Debug {
313    /// Run benchmark
314    fn run_benchmark(&self, plugin: &mut dyn OptimizerPlugin<A>) -> BenchmarkResult<A>;
315
316    /// Get benchmark name
317    fn name(&self) -> &str;
318
319    /// Get benchmark description
320    fn description(&self) -> &str;
321
322    /// Get benchmark category
323    fn category(&self) -> BenchmarkCategory;
324}
325
/// Benchmark categories.
///
/// Derives equality, hashing and `Copy` so a category can be compared and
/// used as a map or set key when grouping benchmarks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BenchmarkCategory {
    /// Speed benchmarks
    Speed,
    /// Memory benchmarks
    Memory,
    /// Accuracy benchmarks
    Accuracy,
    /// Scalability benchmarks
    Scalability,
    /// Robustness benchmarks
    Robustness,
}
340
341/// Benchmark result
342#[derive(Debug, Clone)]
343pub struct BenchmarkResult<A: Float> {
344    /// Benchmark name
345    pub name: String,
346    /// Score (higher is better)
347    pub score: f64,
348    /// Metrics
349    pub metrics: HashMap<String, f64>,
350    /// Execution time
351    pub execution_time: std::time::Duration,
352    /// Memory usage
353    pub memory_usage: usize,
354    /// Additional data
355    pub data: HashMap<String, A>,
356}
357
/// Settings that control benchmark runs.
#[derive(Clone, Debug)]
pub struct BenchmarkConfig {
    /// How many benchmark runs to perform
    pub runs: usize,
    /// How many warmup iterations to perform
    pub warmup_iterations: usize,
    /// Problem sizes to benchmark
    pub problem_sizes: Vec<usize>,
    /// Random seeds to use
    pub random_seeds: Vec<u64>,
}
370
371/// Plugin development helper macros and utilities
372impl PluginSDK {
373    /// Create a plugin template with common functionality
374    pub fn create_plugin_template(name: &str) -> PluginTemplate {
375        PluginTemplate::new(name)
376    }
377
378    /// Validate plugin configuration schema
379    pub fn validate_config_schema(schema: &ConfigSchema) -> Result<()> {
380        for (field_name, field_schema) in &schema.fields {
381            if field_name.is_empty() {
382                return Err(OptimError::InvalidConfig(
383                    "Field name cannot be empty".to_string(),
384                ));
385            }
386
387            if field_schema.description.is_empty() {
388                return Err(OptimError::InvalidConfig(format!(
389                    "Field '{}' must have a description",
390                    field_name
391                )));
392            }
393        }
394        Ok(())
395    }
396
397    /// Generate plugin manifest template
398    pub fn generate_plugin_manifest(info: &PluginInfo) -> String {
399        format!(
400            r#"[plugin]
401name = "{}"
402version = "{}"
403description = "{}"
404author = "{}"
405license = "{}"
406entry_point = "plugin_main"
407
408[build]
409rust_version = "1.70.0"
410target = "*"
411profile = "release"
412
413[runtime]
414min_rust_version = "1.70.0"
415"#,
416            info.name, info.version, info.description, info.author, info.license
417        )
418    }
419
420    /// Create default test configuration
421    pub fn default_test_config() -> TestConfig {
422        TestConfig {
423            iterations: 100,
424            tolerance: 1e-6,
425            random_seed: 42,
426            enable_performance_tests: true,
427            enable_memory_tests: true,
428            enable_convergence_tests: true,
429        }
430    }
431
432    /// Create performance baseline from existing optimizer
433    pub fn create_performance_baseline<A>(
434        optimizer: &mut dyn OptimizerPlugin<A>,
435        test_data: &[(Array1<A>, Array1<A>)],
436    ) -> PerformanceBaseline
437    where
438        A: Float + Debug + Send + Sync + 'static,
439    {
440        let start_time = std::time::Instant::now();
441        let mut total_memory = 0;
442
443        for (params, gradients) in test_data {
444            let _result = optimizer.step(params, gradients);
445            total_memory += optimizer.memory_usage().current_usage;
446        }
447
448        let execution_time = start_time.elapsed();
449        let _avg_memory = total_memory / test_data.len();
450
451        PerformanceBaseline {
452            target: PlatformTarget::CPU,
453            throughput_ops_per_sec: test_data.len() as f64 / execution_time.as_secs_f64(),
454            latency_ms: execution_time.as_secs_f64() * 1000.0 / test_data.len() as f64,
455            memory_usage_mb: total_memory as f64 / (1024.0 * 1024.0),
456            energy_consumption_joules: None,
457            accuracy_metrics: HashMap::new(),
458        }
459    }
460}
461
462/// Plugin template for rapid development
463#[derive(Debug)]
464pub struct PluginTemplate {
465    /// Template name
466    name: String,
467    /// Template structure
468    structure: TemplateStructure,
469}
470
471/// Template structure definition
472#[derive(Debug)]
473pub struct TemplateStructure {
474    /// Source files
475    pub source_files: Vec<TemplateFile>,
476    /// Configuration files
477    pub config_files: Vec<TemplateFile>,
478    /// Test files
479    pub test_files: Vec<TemplateFile>,
480    /// Documentation files
481    pub doc_files: Vec<TemplateFile>,
482}
483
484/// Template file
485#[derive(Debug)]
486pub struct TemplateFile {
487    /// File path
488    pub path: String,
489    /// File content
490    pub content: String,
491    /// File type
492    pub file_type: TemplateFileType,
493}
494
/// Template file types.
///
/// Derives `Copy` and equality so a file's type can be compared and passed
/// around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemplateFileType {
    /// Rust source file
    RustSource,
    /// TOML configuration
    TomlConfig,
    /// Markdown documentation
    Markdown,
    /// Test file
    Test,
}
507
508impl PluginTemplate {
509    /// Create a new plugin template
510    pub fn new(name: &str) -> Self {
511        let structure = Self::create_default_structure(name);
512        Self {
513            name: name.to_string(),
514            structure,
515        }
516    }
517
518    /// Generate template files to directory
519    pub fn generate_to_directory(&self, outputdir: &std::path::Path) -> Result<()> {
520        std::fs::create_dir_all(outputdir)?;
521
522        for file in &self.structure.source_files {
523            let file_path = outputdir.join(&file.path);
524            if let Some(parent) = file_path.parent() {
525                std::fs::create_dir_all(parent)?;
526            }
527            std::fs::write(&file_path, &file.content)?;
528        }
529
530        for file in &self.structure.config_files {
531            let file_path = outputdir.join(&file.path);
532            std::fs::write(&file_path, &file.content)?;
533        }
534
535        for file in &self.structure.test_files {
536            let file_path = outputdir.join(&file.path);
537            if let Some(parent) = file_path.parent() {
538                std::fs::create_dir_all(parent)?;
539            }
540            std::fs::write(&file_path, &file.content)?;
541        }
542
543        Ok(())
544    }
545
546    fn create_default_structure(name: &str) -> TemplateStructure {
547        let lib_rs_content = format!(
548            r#"//! {} optimizer plugin
549//
550// This is an auto-generated plugin template.
551
552use optirs_core::plugin::*;
553use scirs2_core::ndarray::Array1;
554use scirs2_core::numeric::Float;
555
556#[derive(Debug)]
557pub struct {}Optimizer<A: Float> {{
558    learning_rate: A,
559    // Add your optimizer state here
560}}
561
562impl<A: Float + Send + Sync> {}Optimizer<A> {{
563    pub fn new(_learningrate: A) -> Self {{
564        Self {{
565            learning_rate,
566        }}
567    }}
568}}
569
570impl<A: Float + std::fmt::Debug + Send + Sync + 'static> OptimizerPlugin<A> for {}Optimizer<A> {{
571    fn step(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>> {{
572        // Implement your optimization step here
573        Ok(params - &(gradients * self.learning_rate))
574    }}
575    
576    fn name(&self) -> &str {{
577        "{}"
578    }}
579    
580    fn version(&self) -> &str {{
581        "0.1.0"
582    }}
583    
584    fn plugin_info(&self) -> PluginInfo {{
585        create_plugin_info("{}", "0.1.0", "Plugin Developer")
586    }}
587    
588    fn capabilities(&self) -> PluginCapabilities {{
589        create_basic_capabilities()
590    }}
591    
592    fn initialize(&mut self, paramshape: &[usize]) -> Result<()> {{
593        Ok(())
594    }}
595    
596    fn reset(&mut self) -> Result<()> {{
597        Ok(())
598    }}
599    
600    fn get_config(&self) -> OptimizerConfig {{
601        OptimizerConfig::default()
602    }}
603    
604    fn set_config(&mut self, config: OptimizerConfig) -> Result<()> {{
605        Ok(())
606    }}
607    
608    fn get_state(&self) -> Result<OptimizerState> {{
609        Ok(OptimizerState::default())
610    }}
611    
612    fn set_state(&mut self, state: OptimizerState) -> Result<()> {{
613        Ok(())
614    }}
615    
616    fn clone_plugin(&self) -> Box<dyn OptimizerPlugin<A>> {{
617        Box::new(Self::new(self.learning_rate))
618    }}
619}}
620
621// Plugin factory implementation
622#[derive(Debug)]
623pub struct {}Factory;
624
625impl<A: Float + std::fmt::Debug + Send + Sync + 'static> OptimizerPluginFactory<A> for {}Factory {{
626    fn create_optimizer(&self, config: OptimizerConfig) -> Result<Box<dyn OptimizerPlugin<A>>> {{
627        let learning_rate = A::from(config.learning_rate).unwrap();
628        Ok(Box::new({}Optimizer::new(learning_rate)))
629    }}
630    
631    fn factory_info(&self) -> PluginInfo {{
632        create_plugin_info("{}", "0.1.0", "Plugin Developer")
633    }}
634    
635    fn validate_config(&self, config: &OptimizerConfig) -> Result<()> {{
636        if config.learning_rate <= 0.0 {{
637            return Err(OptimError::InvalidConfig(
638                "Learning rate must be positive".to_string(),
639            ));
640        }}
641        Ok(())
642    }}
643    
644    fn default_config(&self) -> OptimizerConfig {{
645        OptimizerConfig {{
646            learning_rate: 0.001,
647            ..Default::default()
648        }}
649    }}
650    
651    fn config_schema(&self) -> ConfigSchema {{
652        let mut schema = ConfigSchema {{
653            fields: std::collections::HashMap::new(),
654            required_fields: vec!["learning_rate".to_string()],
655            version: "1.0".to_string(),
656        }};
657        
658        schema.fields.insert(
659            "learning_rate".to_string(),
660            FieldSchema {{
661                field_type: FieldType::Float {{ min: Some(0.0), max: None }},
662                description: "Learning rate for optimization".to_string(),
663                default_value: Some(ConfigValue::Float(0.001)),
664                constraints: vec![ValidationConstraint::Positive],
665                required: true,
666            }},
667        );
668        
669        schema
670    }}
671}}
672"#,
673            name, name, name, name, name, name, name, name, name, name
674        );
675
676        let plugin_toml_content = format!(
677            r#"[plugin]
678name = "{}"
679version = "0.1.0"
680description = "Custom optimizer plugin"
681author = "Plugin Developer"
682license = "MIT"
683entry_point = "plugin_main"
684
685[build]
686rust_version = "1.70.0"
687target = "*"
688profile = "release"
689
690[runtime]
691min_rust_version = "1.70.0"
692"#,
693            name
694        );
695
696        let test_content = format!(
697            r#"//! Tests for {} optimizer plugin
698
699use super::*;
700use scirs2_core::ndarray::Array1;
701
702#[test]
703#[allow(dead_code)]
704fn test_{}_basic_functionality() {{
705    let mut optimizer = {}Optimizer::new(0.01);
706    let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
707    let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
708    
709    let result = optimizer.step(&params, &gradients).unwrap();
710    
711    // Verify the result
712    assert!((result[0] - 0.999).abs() < 1e-6);
713    assert!((result[1] - 1.998).abs() < 1e-6);
714    assert!((result[2] - 2.997).abs() < 1e-6);
715}}
716
717#[test]
718#[allow(dead_code)]
719fn test_{}_convergence() {{
720    let mut optimizer = {}Optimizer::new(0.1);
721    let mut params = Array1::from_vec(vec![1.0, 1.0]);
722    
723    // Optimize towards zero
724    for _ in 0..100 {{
725        let gradients = &params * 2.0; // Gradient of x^2
726        params = optimizer.step(&params, &gradients).unwrap();
727    }}
728    
729    // Should converge close to zero
730    assert!(params.iter().all(|&x| x.abs() < 0.1));
731}}
732"#,
733            name,
734            name.to_lowercase(),
735            name,
736            name.to_lowercase(),
737            name
738        );
739
740        TemplateStructure {
741            source_files: vec![TemplateFile {
742                path: "src/lib.rs".to_string(),
743                content: lib_rs_content,
744                file_type: TemplateFileType::RustSource,
745            }],
746            config_files: vec![TemplateFile {
747                path: "plugin.toml".to_string(),
748                content: plugin_toml_content,
749                file_type: TemplateFileType::TomlConfig,
750            }],
751            test_files: vec![TemplateFile {
752                path: "tests/integration_tests.rs".to_string(),
753                content: test_content,
754                file_type: TemplateFileType::Test,
755            }],
756            doc_files: vec![],
757        }
758    }
759}
760
761// Implementation for base optimizer plugin
762
763impl<A: Float + Debug + Send + Sync + 'static> BaseOptimizerPlugin<A> {
764    /// Create a new base optimizer plugin
765    pub fn new(info: PluginInfo, capabilities: PluginCapabilities) -> Self {
766        Self {
767            info,
768            capabilities,
769            config: OptimizerConfig::default(),
770            state: BaseOptimizerState::new(),
771            metrics: PerformanceMetrics::default(),
772            memory_usage: MemoryUsage::default(),
773            event_handlers: Vec::new(),
774        }
775    }
776
777    /// Add event handler
778    pub fn add_event_handler(&mut self, handler: Box<dyn PluginEventHandler>) {
779        self.event_handlers.push(handler);
780    }
781
782    /// Update performance metrics
783    pub fn update_metrics(&mut self, steptime: std::time::Duration) {
784        self.metrics.total_steps += 1;
785        self.metrics.avg_step_time = (self.metrics.avg_step_time
786            * (self.metrics.total_steps - 1) as f64
787            + steptime.as_secs_f64())
788            / self.metrics.total_steps as f64;
789        self.metrics.throughput = 1.0 / self.metrics.avg_step_time;
790    }
791}
792
793impl<A: Float + std::fmt::Debug + Send + Sync> BaseOptimizerState<A> {
794    fn new() -> Self {
795        Self {
796            step_count: 0,
797            param_count: 0,
798            lr_history: Vec::new(),
799            grad_norm_history: Vec::new(),
800            param_change_history: Vec::new(),
801            custom_state: HashMap::new(),
802        }
803    }
804}
805
806// Default implementations
807
808impl Default for TestConfig {
809    fn default() -> Self {
810        Self {
811            iterations: 100,
812            tolerance: 1e-6,
813            random_seed: 42,
814            enable_performance_tests: true,
815            enable_memory_tests: true,
816            enable_convergence_tests: true,
817        }
818    }
819}
820
821impl Default for BenchmarkConfig {
822    fn default() -> Self {
823        Self {
824            runs: 10,
825            warmup_iterations: 5,
826            problem_sizes: vec![10, 100, 1000],
827            random_seeds: vec![42, 123, 456],
828        }
829    }
830}
831
/// Macro for creating a simple optimizer plugin.
///
/// `create_optimizer_plugin!(Name, step_fn)` defines a generic struct `Name<A>`
/// with a `new()` constructor, a `Default` impl and an `OptimizerPlugin<A>`
/// implementation whose `step` forwards to `step_fn(self, params, gradients)`.
/// Every other trait method gets a minimal implementation backed by the
/// stored `OptimizerConfig` and `OptimizerState`.
///
/// The expansion refers to `Float`, `Array1`, `Result`, `OptimizerPlugin`,
/// `OptimizerConfig`, `OptimizerState`, `PluginInfo`, `PluginCapabilities`,
/// `create_plugin_info` and `create_basic_capabilities` by bare name, so they
/// must be in scope where the macro is invoked.
#[macro_export]
macro_rules! create_optimizer_plugin {
    ($name:ident, $step_fn:expr) => {
        #[derive(Debug)]
        pub struct $name<A: Float> {
            config: OptimizerConfig,
            state: OptimizerState,
            // Marker only: ties the plugin to its scalar type without storing one.
            _phantom: std::marker::PhantomData<A>,
        }

        impl<A: Float + Send + Sync> $name<A> {
            pub fn new() -> Self {
                Self {
                    config: OptimizerConfig::default(),
                    state: OptimizerState::default(),
                    _phantom: std::marker::PhantomData,
                }
            }
        }

        impl<A: Float + Send + Sync> Default for $name<A> {
            fn default() -> Self {
                Self::new()
            }
        }

        impl<A: Float + std::fmt::Debug + Send + Sync + 'static> OptimizerPlugin<A> for $name<A> {
            fn step(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>> {
                $step_fn(self, params, gradients)
            }

            fn name(&self) -> &str {
                stringify!($name)
            }

            fn version(&self) -> &str {
                "0.1.0"
            }

            fn plugin_info(&self) -> PluginInfo {
                create_plugin_info(stringify!($name), "0.1.0", "Auto-generated")
            }

            fn capabilities(&self) -> PluginCapabilities {
                create_basic_capabilities()
            }

            fn initialize(&mut self, _paramshape: &[usize]) -> Result<()> {
                Ok(())
            }

            fn reset(&mut self) -> Result<()> {
                self.state = OptimizerState::default();
                Ok(())
            }

            fn get_config(&self) -> OptimizerConfig {
                self.config.clone()
            }

            fn set_config(&mut self, config: OptimizerConfig) -> Result<()> {
                self.config = config;
                Ok(())
            }

            fn get_state(&self) -> Result<OptimizerState> {
                Ok(self.state.clone())
            }

            fn set_state(&mut self, state: OptimizerState) -> Result<()> {
                self.state = state;
                Ok(())
            }

            // The clone starts from a fresh `new()`; config and state are not copied.
            fn clone_plugin(&self) -> Box<dyn OptimizerPlugin<A>> {
                Box::new(Self::new())
            }
        }
    };
}
907
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unbounded float field schema with no default and no
    /// constraints, as used by the schema validation test.
    fn float_field(description: &str, required: bool) -> FieldSchema {
        FieldSchema {
            field_type: FieldType::Float {
                min: None,
                max: None,
            },
            description: description.to_string(),
            default_value: None,
            constraints: Vec::new(),
            required,
        }
    }

    #[test]
    fn test_plugin_template_creation() {
        let template = PluginTemplate::new("TestOptimizer");
        assert_eq!(template.name, "TestOptimizer");
        assert!(!template.structure.source_files.is_empty());
    }

    #[test]
    fn test_test_config_default() {
        let config = TestConfig::default();
        assert_eq!(config.iterations, 100);
        assert!(config.enable_performance_tests);
    }

    #[test]
    fn test_sdk_config_validation() {
        let mut schema = ConfigSchema {
            fields: HashMap::new(),
            required_fields: vec!["test_field".to_string()],
            version: "1.0".to_string(),
        };

        // A named, described field is accepted.
        schema
            .fields
            .insert("test_field".to_string(), float_field("Test field", true));
        assert!(PluginSDK::validate_config_schema(&schema).is_ok());

        // Adding a field with an empty name makes validation fail.
        schema
            .fields
            .insert("".to_string(), float_field("Test", false));
        assert!(PluginSDK::validate_config_schema(&schema).is_err());
    }
}