optirs_core/
self_tuning.rs

1// Self-tuning optimization strategies
2//
3// This module provides adaptive optimization strategies that automatically
4// tune hyperparameters, select optimizers, and adjust configurations based
5// on training dynamics and problem characteristics.
6
7#[allow(dead_code)]
8use crate::error::Result;
9use crate::optimizers::*;
10use crate::schedulers::*;
11use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
12use scirs2_core::numeric::Float;
13use scirs2_core::random::{thread_rng, Rng};
14use std::collections::{HashMap, VecDeque};
15use std::fmt::Debug;
16use std::time::{Duration, Instant};
17
18/// Configuration for self-tuning optimization
19#[derive(Debug, Clone)]
20pub struct SelfTuningConfig {
21    /// Window size for performance evaluation
22    pub evaluation_window: usize,
23
24    /// Minimum improvement threshold for parameter updates
25    pub improvement_threshold: f64,
26
27    /// Maximum number of optimizer switches per epoch
28    pub max_switches_per_epoch: usize,
29
30    /// Enable automatic learning rate adjustment
31    pub auto_lr_adjustment: bool,
32
33    /// Enable automatic optimizer selection
34    pub auto_optimizer_selection: bool,
35
36    /// Enable automatic batch size tuning
37    pub auto_batch_size_tuning: bool,
38
39    /// Warmup period before starting adaptations
40    pub warmup_steps: usize,
41
42    /// Exploration probability for optimizer selection
43    pub exploration_rate: f64,
44
45    /// Decay rate for exploration
46    pub exploration_decay: f64,
47
48    /// Performance metric to optimize
49    pub target_metric: TargetMetric,
50}
51
52impl Default for SelfTuningConfig {
53    fn default() -> Self {
54        Self {
55            evaluation_window: 100,
56            improvement_threshold: 0.01,
57            max_switches_per_epoch: 3,
58            auto_lr_adjustment: true,
59            auto_optimizer_selection: true,
60            auto_batch_size_tuning: false,
61            warmup_steps: 1000,
62            exploration_rate: 0.1,
63            exploration_decay: 0.99,
64            target_metric: TargetMetric::Loss,
65        }
66    }
67}
68
/// The quantity the self-tuning logic tries to improve.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the metric can be used as a key in
/// hash-based collections (e.g. per-metric bookkeeping).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TargetMetric {
    /// Minimize loss.
    Loss,
    /// Maximize accuracy.
    Accuracy,
    /// Minimize convergence time.
    ConvergenceTime,
    /// Maximize training throughput.
    Throughput,
    /// Custom metric (user-defined).
    Custom,
}
83
/// One training step's worth of measurements fed to the self-tuning optimizer.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// Loss observed at this step.
    pub loss: f64,

    /// Accuracy at this step, when the caller measures it.
    pub accuracy: Option<f64>,

    /// Norm of the gradient.
    pub gradient_norm: f64,

    /// Throughput in samples per second.
    pub throughput: f64,

    /// Memory usage in megabytes.
    pub memory_usage: f64,

    /// Wall-clock duration of the step.
    pub step_time: Duration,

    /// Learning rate in effect for the step.
    pub learning_rate: f64,

    /// Name of the optimizer that performed the step.
    pub optimizer_type: String,

    /// Additional user-defined metrics, keyed by name.
    pub custom_metrics: HashMap<String, f64>,
}
114
115/// Adaptive optimizer that automatically tunes hyperparameters
116pub struct SelfTuningOptimizer<A: Float, D: Dimension> {
117    /// Configuration
118    config: SelfTuningConfig,
119
120    /// Current active optimizer
121    current_optimizer: Box<dyn OptimizerTrait<A, D>>,
122
123    /// Available optimizer candidates
124    optimizer_candidates: Vec<OptimizerCandidate<A, D>>,
125
126    /// Performance history
127    performance_history: VecDeque<PerformanceStats>,
128
129    /// Hyperparameter search state
130    search_state: HyperparameterSearchState,
131
132    /// Learning rate scheduler
133    lr_scheduler: Option<Box<dyn LearningRateScheduler<A>>>,
134
135    /// Optimizer selection strategy
136    selection_strategy: OptimizerSelectionStrategy,
137
138    /// Current step count
139    step_count: usize,
140
141    /// Number of optimizer switches in current epoch
142    switches_this_epoch: usize,
143
144    /// Best performance seen so far
145    best_performance: Option<f64>,
146
147    /// Time of last adaptation
148    last_adaptation_time: Instant,
149
150    /// Exploration state for multi-armed bandit
151    bandit_state: BanditState,
152}
153
154/// Optimizer candidate with its configuration
155struct OptimizerCandidate<A: Float, D: Dimension> {
156    /// Name/identifier
157    name: String,
158
159    /// Factory function to create the optimizer
160    factory: Box<dyn Fn() -> Box<dyn OptimizerTrait<A, D>>>,
161
162    /// Performance history for this optimizer
163    performance_history: Vec<f64>,
164
165    /// Usage count
166    usage_count: usize,
167
168    /// Average performance
169    average_performance: f64,
170
171    /// Confidence interval
172    confidence_interval: (f64, f64),
173}
174
175/// Hyperparameter search state
176#[derive(Debug)]
177struct HyperparameterSearchState {
178    /// Current learning rate
179    learning_rate: f64,
180
181    /// Learning rate search bounds
182    lr_bounds: (f64, f64),
183
184    /// Current batch size
185    batch_size: usize,
186
187    /// Batch size search bounds
188    batch_size_bounds: (usize, usize),
189
190    /// Number of search iterations performed
191    search_iterations: usize,
192
193    /// Best hyperparameters found
194    best_hyperparameters: HashMap<String, f64>,
195
196    /// Search algorithm state
197    search_algorithm: SearchAlgorithm,
198}
199
200/// Hyperparameter search algorithms
201#[derive(Debug)]
202enum SearchAlgorithm {
203    /// Random search
204    Random {
205        /// Random number generator seed
206        seed: u64,
207    },
208
209    /// Bayesian optimization
210    Bayesian {
211        /// Gaussian process state
212        gp_state: GaussianProcessState,
213    },
214
215    /// Grid search
216    Grid {
217        /// Current grid position
218        position: Vec<usize>,
219        /// Grid dimensions
220        dimensions: Vec<usize>,
221    },
222
223    /// Successive halving (Hyperband)
224    SuccessiveHalving {
225        /// Current bracket
226        bracket: usize,
227        /// Configurations in current round
228        configurations: Vec<HashMap<String, f64>>,
229    },
230}
231
232/// Gaussian process state for Bayesian optimization
233#[derive(Debug)]
234struct GaussianProcessState {
235    /// Observed points
236    observed_points: Vec<Vec<f64>>,
237
238    /// Observed values
239    observed_values: Vec<f64>,
240
241    /// Kernel hyperparameters
242    kernel_params: Vec<f64>,
243
244    /// Acquisition function type
245    acquisition_function: AcquisitionFunction,
246}
247
/// Acquisition functions available to Bayesian optimization.
#[derive(Debug, Clone, Copy)]
enum AcquisitionFunction {
    /// Expected improvement over the incumbent.
    ExpectedImprovement,
    /// Probability of improving on the incumbent.
    ProbabilityOfImprovement,
    /// Optimistic upper confidence bound.
    UpperConfidenceBound,
}
255
256/// Optimizer selection strategies
257#[derive(Debug, Clone)]
258enum OptimizerSelectionStrategy {
259    /// Multi-armed bandit approach
260    MultiArmedBandit {
261        /// Bandit algorithm type
262        algorithm: BanditAlgorithm,
263    },
264
265    /// Performance-based selection
266    PerformanceBased {
267        /// Minimum performance difference for switching
268        min_difference: f64,
269    },
270
271    /// Round-robin testing
272    RoundRobin {
273        /// Current optimizer index
274        current_index: usize,
275    },
276
277    /// Meta-learning based selection
278    MetaLearning {
279        /// Problem characteristics
280        problem_features: Vec<f64>,
281        /// Learned optimizer mappings
282        optimizer_mappings: HashMap<String, f64>,
283    },
284}
285
/// Bandit algorithms usable for optimizer selection.
#[derive(Debug, Clone, Copy)]
enum BanditAlgorithm {
    /// Explore uniformly with a fixed probability, otherwise exploit.
    EpsilonGreedy,
    /// Upper confidence bound (UCB1).
    UCB1,
    /// Randomized sampling of each arm's estimated reward.
    ThompsonSampling,
    /// Contextual linear UCB.
    LinUCB,
}
294
/// Per-arm statistics for bandit-based optimizer selection.
///
/// Every vector holds one entry per optimizer candidate, in candidate order.
#[derive(Debug)]
struct BanditState {
    /// Estimated reward of each arm.
    reward_estimates: Vec<f64>,

    /// Uncertainty around each arm's reward estimate.
    confidence_bounds: Vec<f64>,

    /// How many times each arm has been selected.
    selection_counts: Vec<usize>,

    /// Total number of selections across all arms.
    total_selections: usize,

    /// Weight of the exploration term.
    exploration_param: f64,
}
313
314/// Trait for optimizer implementations that can be used with self-tuning
315pub trait OptimizerTrait<A: Float + ScalarOperand + Debug, D: Dimension>: Send + Sync {
316    /// Get optimizer name
317    fn name(&self) -> &str;
318
319    /// Perform optimization step
320    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()>;
321
322    /// Get current learning rate
323    fn learning_rate(&self) -> A;
324
325    /// Set learning rate
326    fn set_learning_rate(&mut self, lr: A);
327
328    /// Get optimizer state for serialization
329    fn get_state(&self) -> HashMap<String, Vec<u8>>;
330
331    /// Set optimizer state from serialization
332    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()>;
333
334    /// Clone the optimizer
335    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>>;
336}
337
338impl<
339        A: Float + ScalarOperand + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive,
340        D: Dimension + 'static,
341    > SelfTuningOptimizer<A, D>
342{
343    /// Create new self-tuning optimizer
344    pub fn new(config: SelfTuningConfig) -> Result<Self> {
345        let mut optimizer_candidates = Vec::new();
346
347        // Add default optimizer candidates
348        optimizer_candidates.push(OptimizerCandidate {
349            name: "Adam".to_string(),
350            factory: Box::new(|| Box::new(AdamOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.0))),
351            performance_history: Vec::new(),
352            usage_count: 0,
353            average_performance: 0.0,
354            confidence_interval: (0.0, 0.0),
355        });
356
357        optimizer_candidates.push(OptimizerCandidate {
358            name: "SGD".to_string(),
359            factory: Box::new(|| Box::new(SGDOptimizerWrapper::new(0.01, 0.9, 0.0, false))),
360            performance_history: Vec::new(),
361            usage_count: 0,
362            average_performance: 0.0,
363            confidence_interval: (0.0, 0.0),
364        });
365
366        optimizer_candidates.push(OptimizerCandidate {
367            name: "AdamW".to_string(),
368            factory: Box::new(|| {
369                Box::new(AdamWOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.01))
370            }),
371            performance_history: Vec::new(),
372            usage_count: 0,
373            average_performance: 0.0,
374            confidence_interval: (0.0, 0.0),
375        });
376
377        // Start with Adam as default
378        let current_optimizer = (optimizer_candidates[0].factory)();
379
380        let search_state = HyperparameterSearchState {
381            learning_rate: 0.001,
382            lr_bounds: (1e-6, 1.0),
383            batch_size: 32,
384            batch_size_bounds: (8, 512),
385            search_iterations: 0,
386            best_hyperparameters: HashMap::new(),
387            search_algorithm: SearchAlgorithm::Random { seed: 42 },
388        };
389
390        let selection_strategy = OptimizerSelectionStrategy::MultiArmedBandit {
391            algorithm: BanditAlgorithm::UCB1,
392        };
393
394        let bandit_state = BanditState {
395            reward_estimates: vec![0.0; optimizer_candidates.len()],
396            confidence_bounds: vec![1.0; optimizer_candidates.len()],
397            selection_counts: vec![0; optimizer_candidates.len()],
398            total_selections: 0,
399            exploration_param: 2.0,
400        };
401
402        Ok(Self {
403            config,
404            current_optimizer,
405            optimizer_candidates,
406            performance_history: VecDeque::new(),
407            search_state,
408            lr_scheduler: None,
409            selection_strategy,
410            step_count: 0,
411            switches_this_epoch: 0,
412            best_performance: None,
413            last_adaptation_time: Instant::now(),
414            bandit_state,
415        })
416    }
417
418    /// Add a custom optimizer candidate
419    pub fn add_optimizer_candidate<F>(&mut self, name: String, factory: F)
420    where
421        F: Fn() -> Box<dyn OptimizerTrait<A, D>> + 'static,
422    {
423        self.optimizer_candidates.push(OptimizerCandidate {
424            name,
425            factory: Box::new(factory),
426            performance_history: Vec::new(),
427            usage_count: 0,
428            average_performance: 0.0,
429            confidence_interval: (0.0, 0.0),
430        });
431
432        // Update bandit state
433        self.bandit_state.reward_estimates.push(0.0);
434        self.bandit_state.confidence_bounds.push(1.0);
435        self.bandit_state.selection_counts.push(0);
436    }
437
438    /// Perform optimization step with automatic tuning
439    pub fn step(
440        &mut self,
441        params: &mut [Array<A, D>],
442        grads: &[Array<A, D>],
443        stats: PerformanceStats,
444    ) -> Result<()> {
445        self.step_count += 1;
446
447        // Record performance
448        self.performance_history.push_back(stats.clone());
449        if self.performance_history.len() > self.config.evaluation_window {
450            self.performance_history.pop_front();
451        }
452
453        // Perform optimization step
454        self.current_optimizer.step(params, grads)?;
455
456        // Self-tuning adaptations
457        if self.step_count > self.config.warmup_steps {
458            self.maybe_adapt_optimizer(&stats)?;
459            self.maybe_adapt_learning_rate(&stats)?;
460            self.maybe_adapt_hyperparameters(&stats)?;
461        }
462
463        // Update best performance
464        let current_performance = self.extract_performance_metric(&stats);
465        if let Some(performance) = current_performance {
466            if self.best_performance.is_none()
467                || self.is_better_performance(performance, self.best_performance.unwrap())
468            {
469                self.best_performance = Some(performance);
470            }
471        }
472
473        Ok(())
474    }
475
476    /// Check if we should adapt the optimizer
477    fn maybe_adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
478        if !self.config.auto_optimizer_selection {
479            return Ok(());
480        }
481
482        if self.switches_this_epoch >= self.config.max_switches_per_epoch {
483            return Ok(());
484        }
485
486        let should_adapt = self.should_adapt_optimizer(stats);
487
488        if should_adapt {
489            self.adapt_optimizer(stats)?;
490            self.switches_this_epoch += 1;
491        }
492
493        Ok(())
494    }
495
496    /// Determine if optimizer should be adapted
497    fn should_adapt_optimizer(&self, stats: &PerformanceStats) -> bool {
498        if self.performance_history.len() < self.config.evaluation_window / 2 {
499            return false;
500        }
501
502        // Check for performance degradation or stagnation
503        let recent_performance: Vec<f64> = self
504            .performance_history
505            .iter()
506            .rev()
507            .take(self.config.evaluation_window / 4)
508            .filter_map(|s| self.extract_performance_metric(s))
509            .collect();
510
511        let older_performance: Vec<f64> = self
512            .performance_history
513            .iter()
514            .rev()
515            .skip(self.config.evaluation_window / 4)
516            .take(self.config.evaluation_window / 4)
517            .filter_map(|s| self.extract_performance_metric(s))
518            .collect();
519
520        if recent_performance.is_empty() || older_performance.is_empty() {
521            return false;
522        }
523
524        let recent_avg = recent_performance.iter().sum::<f64>() / recent_performance.len() as f64;
525        let older_avg = older_performance.iter().sum::<f64>() / older_performance.len() as f64;
526
527        // Check for stagnation or degradation
528        match self.config.target_metric {
529            TargetMetric::Loss => {
530                (recent_avg - older_avg).abs() < self.config.improvement_threshold
531                    || recent_avg > older_avg
532            }
533            TargetMetric::Accuracy | TargetMetric::Throughput => {
534                (recent_avg - older_avg).abs() < self.config.improvement_threshold
535                    || recent_avg < older_avg
536            }
537            _ => false,
538        }
539    }
540
541    /// Adapt the optimizer based on performance
542    fn adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
543        let new_optimizer_idx = match &self.selection_strategy {
544            OptimizerSelectionStrategy::MultiArmedBandit { algorithm } => {
545                self.select_optimizer_bandit(*algorithm)
546            }
547            OptimizerSelectionStrategy::PerformanceBased { .. } => {
548                self.select_optimizer_performance_based()
549            }
550            OptimizerSelectionStrategy::RoundRobin { current_index } => {
551                (*current_index + 1) % self.optimizer_candidates.len()
552            }
553            OptimizerSelectionStrategy::MetaLearning { .. } => {
554                self.select_optimizer_meta_learning(stats)
555            }
556        };
557
558        // Switch to new optimizer
559        if new_optimizer_idx < self.optimizer_candidates.len() {
560            let current_lr = self.current_optimizer.learning_rate();
561            let current_state = self.current_optimizer.get_state();
562
563            self.current_optimizer = (self.optimizer_candidates[new_optimizer_idx].factory)();
564            self.current_optimizer.set_learning_rate(current_lr);
565
566            // Try to transfer compatible state
567            if self.current_optimizer.set_state(current_state).is_err() {
568                // State transfer failed, continue with fresh state
569            }
570
571            // Update usage statistics
572            self.optimizer_candidates[new_optimizer_idx].usage_count += 1;
573        }
574
575        Ok(())
576    }
577
578    /// Select optimizer using multi-armed bandit
579    fn select_optimizer_bandit(&mut self, algorithm: BanditAlgorithm) -> usize {
580        match algorithm {
581            BanditAlgorithm::UCB1 => self.select_ucb1(),
582            BanditAlgorithm::EpsilonGreedy => self.select_epsilon_greedy(),
583            BanditAlgorithm::ThompsonSampling => self.select_thompson_sampling(),
584            BanditAlgorithm::LinUCB => self.select_linucb(),
585        }
586    }
587
588    /// UCB1 optimizer selection
589    fn select_ucb1(&self) -> usize {
590        if self.bandit_state.total_selections == 0 {
591            return 0;
592        }
593
594        let mut best_score = f64::NEG_INFINITY;
595        let mut best_idx = 0;
596
597        for (i, candidate) in self.optimizer_candidates.iter().enumerate() {
598            let ucb_score = if self.bandit_state.selection_counts[i] == 0 {
599                f64::INFINITY
600            } else {
601                let mean_reward = self.bandit_state.reward_estimates[i];
602                let confidence = (self.bandit_state.exploration_param
603                    * (self.bandit_state.total_selections as f64).ln()
604                    / self.bandit_state.selection_counts[i] as f64)
605                    .sqrt();
606                mean_reward + confidence
607            };
608
609            if ucb_score > best_score {
610                best_score = ucb_score;
611                best_idx = i;
612            }
613        }
614
615        best_idx
616    }
617
618    /// Epsilon-greedy optimizer selection
619    fn select_epsilon_greedy(&self) -> usize {
620        let mut rng = thread_rng();
621
622        if A::from(rng.random::<f64>()).unwrap() < A::from(self.config.exploration_rate).unwrap() {
623            // Explore: random selection
624            rng.gen_range(0..self.optimizer_candidates.len())
625        } else {
626            // Exploit: best performing optimizer
627            self.bandit_state
628                .reward_estimates
629                .iter()
630                .enumerate()
631                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
632                .map(|(idx, _)| idx)
633                .unwrap_or(0)
634        }
635    }
636
637    /// Thompson sampling optimizer selection
638    fn select_thompson_sampling(&self) -> usize {
639        // Simplified Thompson sampling - in practice would use Beta distributions
640        let mut rng = thread_rng();
641
642        let mut best_sample = f64::NEG_INFINITY;
643        let mut best_idx = 0;
644
645        for (i, _) in self.optimizer_candidates.iter().enumerate() {
646            let mean = self.bandit_state.reward_estimates[i];
647            let std = self.bandit_state.confidence_bounds[i];
648            let sample = rng.gen_range(mean - std..mean + std);
649
650            if sample > best_sample {
651                best_sample = sample;
652                best_idx = i;
653            }
654        }
655
656        best_idx
657    }
658
659    /// LinUCB optimizer selection (contextual bandit)
660    fn select_linucb(&self) -> usize {
661        // Simplified LinUCB - would use contextual features in practice
662        self.select_ucb1()
663    }
664
665    /// Performance-based optimizer selection
666    fn select_optimizer_performance_based(&self) -> usize {
667        self.optimizer_candidates
668            .iter()
669            .enumerate()
670            .max_by(|a, b| {
671                a.1.average_performance
672                    .partial_cmp(&b.1.average_performance)
673                    .unwrap()
674            })
675            .map(|(idx, _)| idx)
676            .unwrap_or(0)
677    }
678
679    /// Meta-learning based optimizer selection
680    fn select_optimizer_meta_learning(&self, stats: &PerformanceStats) -> usize {
681        // Simplified meta-learning - would use problem features in practice
682        0
683    }
684
685    /// Adapt learning rate based on performance
686    fn maybe_adapt_learning_rate(&mut self, stats: &PerformanceStats) -> Result<()> {
687        if !self.config.auto_lr_adjustment {
688            return Ok(());
689        }
690
691        // Simple adaptive learning rate based on gradient norm
692        let current_lr = self.current_optimizer.learning_rate().to_f64().unwrap();
693        let gradient_norm = stats.gradient_norm;
694
695        let new_lr = if gradient_norm > 10.0 {
696            // Large gradients - reduce learning rate
697            current_lr * 0.9
698        } else if gradient_norm < 0.1 {
699            // Small gradients - increase learning rate
700            current_lr * 1.1
701        } else {
702            current_lr
703        };
704
705        let clamped_lr = new_lr
706            .max(self.search_state.lr_bounds.0)
707            .min(self.search_state.lr_bounds.1);
708
709        if (clamped_lr - current_lr).abs() > current_lr * 0.01 {
710            self.current_optimizer
711                .set_learning_rate(A::from(clamped_lr).unwrap());
712            self.search_state.learning_rate = clamped_lr;
713        }
714
715        Ok(())
716    }
717
718    /// Adapt other hyperparameters
719    fn maybe_adapt_hyperparameters(&mut self, stats: &PerformanceStats) -> Result<()> {
720        // Placeholder for hyperparameter adaptation
721        // Would implement Bayesian optimization, random search, etc.
722        Ok(())
723    }
724
725    /// Extract performance metric from stats
726    fn extract_performance_metric(&self, stats: &PerformanceStats) -> Option<f64> {
727        match self.config.target_metric {
728            TargetMetric::Loss => Some(stats.loss),
729            TargetMetric::Accuracy => stats.accuracy,
730            TargetMetric::Throughput => Some(stats.throughput),
731            TargetMetric::ConvergenceTime => Some(stats.step_time.as_secs_f64()),
732            TargetMetric::Custom => stats.custom_metrics.values().next().copied(),
733        }
734    }
735
736    /// Check if performance is better
737    fn is_better_performance(&self, new_perf: f64, oldperf: f64) -> bool {
738        match self.config.target_metric {
739            TargetMetric::Loss | TargetMetric::ConvergenceTime => new_perf < oldperf,
740            TargetMetric::Accuracy | TargetMetric::Throughput => new_perf > oldperf,
741            TargetMetric::Custom => new_perf > oldperf, // Assume higher is better for custom
742        }
743    }
744
745    /// Reset epoch counters
746    pub fn reset_epoch(&mut self) {
747        self.switches_this_epoch = 0;
748    }
749
750    /// Get current optimizer information
751    pub fn get_optimizer_info(&self) -> OptimizerInfo {
752        OptimizerInfo {
753            name: self.current_optimizer.name().to_string(),
754            learning_rate: self.current_optimizer.learning_rate().to_f64().unwrap(),
755            step_count: self.step_count,
756            switches_this_epoch: self.switches_this_epoch,
757            performance_window_size: self.performance_history.len(),
758            best_performance: self.best_performance,
759        }
760    }
761
762    /// Get optimization statistics
763    pub fn get_statistics(&self) -> SelfTuningStatistics {
764        let optimizer_usage: HashMap<String, usize> = self
765            .optimizer_candidates
766            .iter()
767            .map(|c| (c.name.clone(), c.usage_count))
768            .collect();
769
770        SelfTuningStatistics {
771            total_steps: self.step_count,
772            total_optimizer_switches: self
773                .optimizer_candidates
774                .iter()
775                .map(|c| c.usage_count)
776                .sum(),
777            optimizer_usage,
778            current_learning_rate: self.search_state.learning_rate,
779            average_step_time: self
780                .performance_history
781                .iter()
782                .map(|s| s.step_time.as_secs_f64())
783                .sum::<f64>()
784                / self.performance_history.len().max(1) as f64,
785            exploration_rate: self.config.exploration_rate,
786        }
787    }
788}
789
/// Snapshot describing the active optimizer and the tuner's progress.
#[derive(Debug, Clone)]
pub struct OptimizerInfo {
    /// Name reported by the active optimizer.
    pub name: String,
    /// Active learning rate.
    pub learning_rate: f64,
    /// Steps taken so far.
    pub step_count: usize,
    /// Optimizer adaptations since the last epoch reset.
    pub switches_this_epoch: usize,
    /// Number of entries currently in the evaluation window.
    pub performance_window_size: usize,
    /// Best target-metric value observed, if any.
    pub best_performance: Option<f64>,
}
800
/// Aggregate statistics about a self-tuning run.
#[derive(Debug, Clone)]
pub struct SelfTuningStatistics {
    /// Steps taken so far.
    pub total_steps: usize,
    /// Total optimizer switches across all candidates.
    pub total_optimizer_switches: usize,
    /// Switch count per candidate name.
    pub optimizer_usage: HashMap<String, usize>,
    /// Learning rate tracked by the search state.
    pub current_learning_rate: f64,
    /// Mean step time in seconds over the evaluation window.
    pub average_step_time: f64,
    /// Current exploration rate.
    pub exploration_rate: f64,
}
811
812// Wrapper implementations for existing optimizers
813struct AdamOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
814    inner: crate::optimizers::Adam<A>,
815    _phantom: std::marker::PhantomData<D>,
816}
817
818impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
819    AdamOptimizerWrapper<A, D>
820{
821    fn new(_lr: f64, beta1: f64, beta2: f64, eps: f64, weightdecay: f64) -> Self {
822        Self {
823            inner: crate::optimizers::Adam::new_with_config(
824                A::from(_lr).unwrap(),
825                A::from(beta1).unwrap(),
826                A::from(beta2).unwrap(),
827                A::from(eps).unwrap(),
828                A::from(weightdecay).unwrap(),
829            ),
830            _phantom: std::marker::PhantomData,
831        }
832    }
833}
834
835impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
836    OptimizerTrait<A, D> for AdamOptimizerWrapper<A, D>
837{
838    fn name(&self) -> &str {
839        "Adam"
840    }
841
842    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
843        if params.len() != grads.len() {
844            return Err(crate::error::OptimError::InvalidParameter(
845                "Mismatched number of parameters and gradients".to_string(),
846            ));
847        }
848
849        for (param, grad) in params.iter_mut().zip(grads.iter()) {
850            let updated = self.inner.step(param, grad)?;
851            *param = updated;
852        }
853        Ok(())
854    }
855
856    fn learning_rate(&self) -> A {
857        self.inner.learning_rate()
858    }
859
860    fn set_learning_rate(&mut self, lr: A) {
861        <crate::optimizers::Adam<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
862            &mut self.inner,
863            lr,
864        );
865    }
866
867    fn get_state(&self) -> HashMap<String, Vec<u8>> {
868        HashMap::new()
869    }
870
871    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
872        Ok(())
873    }
874
875    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
876        Box::new(AdamOptimizerWrapper {
877            inner: self.inner.clone(),
878            _phantom: std::marker::PhantomData,
879        })
880    }
881}
882
883struct SGDOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
884    inner: crate::optimizers::SGD<A>,
885    _phantom: std::marker::PhantomData<D>,
886}
887
888impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
889    SGDOptimizerWrapper<A, D>
890{
891    fn new(_lr: f64, momentum: f64, weightdecay: f64, nesterov: bool) -> Self {
892        Self {
893            inner: crate::optimizers::SGD::new_with_config(
894                A::from(_lr).unwrap(),
895                A::from(momentum).unwrap(),
896                A::from(weightdecay).unwrap(),
897            ),
898            _phantom: std::marker::PhantomData,
899        }
900    }
901}
902
903impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
904    OptimizerTrait<A, D> for SGDOptimizerWrapper<A, D>
905{
906    fn name(&self) -> &str {
907        "SGD"
908    }
909
910    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
911        if params.len() != grads.len() {
912            return Err(crate::error::OptimError::InvalidParameter(
913                "Mismatched number of parameters and gradients".to_string(),
914            ));
915        }
916
917        for (param, grad) in params.iter_mut().zip(grads.iter()) {
918            let updated = self.inner.step(param, grad)?;
919            *param = updated;
920        }
921        Ok(())
922    }
923
924    fn learning_rate(&self) -> A {
925        self.inner.learning_rate()
926    }
927
928    fn set_learning_rate(&mut self, lr: A) {
929        <crate::optimizers::SGD<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
930            &mut self.inner,
931            lr,
932        );
933    }
934
935    fn get_state(&self) -> HashMap<String, Vec<u8>> {
936        HashMap::new()
937    }
938
939    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
940        Ok(())
941    }
942
943    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
944        Box::new(SGDOptimizerWrapper {
945            inner: self.inner.clone(),
946            _phantom: std::marker::PhantomData,
947        })
948    }
949}
950
951struct AdamWOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
952    inner: crate::optimizers::AdamW<A>,
953    _phantom: std::marker::PhantomData<D>,
954}
955
956impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
957    AdamWOptimizerWrapper<A, D>
958{
959    fn new(_lr: f64, beta1: f64, beta2: f64, eps: f64, weightdecay: f64) -> Self {
960        Self {
961            inner: crate::optimizers::AdamW::new_with_config(
962                A::from(_lr).unwrap(),
963                A::from(beta1).unwrap(),
964                A::from(beta2).unwrap(),
965                A::from(eps).unwrap(),
966                A::from(weightdecay).unwrap(),
967            ),
968            _phantom: std::marker::PhantomData,
969        }
970    }
971}
972
973impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
974    OptimizerTrait<A, D> for AdamWOptimizerWrapper<A, D>
975{
976    fn name(&self) -> &str {
977        "AdamW"
978    }
979
980    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
981        if params.len() != grads.len() {
982            return Err(crate::error::OptimError::InvalidParameter(
983                "Mismatched number of parameters and gradients".to_string(),
984            ));
985        }
986
987        for (param, grad) in params.iter_mut().zip(grads.iter()) {
988            let updated = self.inner.step(param, grad)?;
989            *param = updated;
990        }
991        Ok(())
992    }
993
994    fn learning_rate(&self) -> A {
995        self.inner.learning_rate()
996    }
997
998    fn set_learning_rate(&mut self, lr: A) {
999        <crate::optimizers::AdamW<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
1000            &mut self.inner,
1001            lr,
1002        );
1003    }
1004
1005    fn get_state(&self) -> HashMap<String, Vec<u8>> {
1006        HashMap::new()
1007    }
1008
1009    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
1010        Ok(())
1011    }
1012
1013    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
1014        Box::new(AdamWOptimizerWrapper {
1015            inner: self.inner.clone(),
1016            _phantom: std::marker::PhantomData,
1017        })
1018    }
1019}
1020
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Ix1};
    use std::time::Duration;

    /// Concrete self-tuning optimizer instantiation shared by the tests below.
    type TestOptimizer = SelfTuningOptimizer<f64, Ix1>;

    /// Constructs a `TestOptimizer` from `config`, panicking if construction fails.
    fn build(config: SelfTuningConfig) -> TestOptimizer {
        TestOptimizer::new(config).unwrap()
    }

    #[test]
    fn test_self_tuning_config_default() {
        let cfg = SelfTuningConfig::default();
        assert_eq!(cfg.evaluation_window, 100);
        assert!(cfg.auto_lr_adjustment);
        assert!(cfg.auto_optimizer_selection);
    }

    #[test]
    fn test_self_tuning_optimizer_creation() {
        let created: Result<TestOptimizer> = TestOptimizer::new(SelfTuningConfig::default());
        assert!(created.is_ok());
    }

    #[test]
    fn test_performance_stats() {
        let snapshot = PerformanceStats {
            loss: 0.5,
            accuracy: Some(0.9),
            gradient_norm: 1.2,
            throughput: 100.0,
            memory_usage: 1024.0,
            step_time: Duration::from_millis(50),
            learning_rate: 0.001,
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
        };

        assert_eq!(snapshot.loss, 0.5);
        assert_eq!(snapshot.accuracy, Some(0.9));
    }

    #[test]
    fn test_optimizer_step() {
        let mut tuner = build(SelfTuningConfig::default());

        let mut params = vec![Array1::zeros(10)];
        let grads = vec![Array1::ones(10)];
        let snapshot = PerformanceStats {
            loss: 1.0,
            accuracy: None,
            gradient_norm: 1.0,
            throughput: 50.0,
            memory_usage: 512.0,
            step_time: Duration::from_millis(10),
            learning_rate: 0.001,
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
        };

        let outcome = tuner.step(&mut params, &grads, snapshot);
        assert!(outcome.is_ok());

        // After a single step the default optimizer should still be active.
        let info = tuner.get_optimizer_info();
        assert_eq!(info.name, "Adam");
        assert_eq!(info.step_count, 1);
    }

    #[test]
    fn test_bandit_selection() {
        let tuner = build(SelfTuningConfig::default());

        let chosen = tuner.select_ucb1();
        assert!(chosen < tuner.optimizer_candidates.len());
    }

    #[test]
    fn test_performance_metric_extraction() {
        let tuner = build(SelfTuningConfig {
            target_metric: TargetMetric::Loss,
            ..Default::default()
        });

        let snapshot = PerformanceStats {
            loss: 0.8,
            accuracy: Some(0.85),
            gradient_norm: 1.1,
            throughput: 75.0,
            memory_usage: 800.0,
            step_time: Duration::from_millis(20),
            learning_rate: 0.001,
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
        };

        // With `TargetMetric::Loss` the extracted metric is expected to equal the loss field.
        assert_eq!(tuner.extract_performance_metric(&snapshot), Some(0.8));
    }

    #[test]
    fn test_statistics() {
        let tuner = build(SelfTuningConfig::default());

        let summary = tuner.get_statistics();
        assert_eq!(summary.total_steps, 0);
        assert!(summary.optimizer_usage.contains_key("Adam"));
    }
}