// optirs_core/self_tuning.rs
// Self-tuning optimization strategies
//
// This module provides adaptive optimization strategies that automatically
// tune hyperparameters, select optimizers, and adjust configurations based
// on training dynamics and problem characteristics.

#[allow(dead_code)]
use crate::error::Result;
use crate::optimizers::*;
use crate::schedulers::*;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::time::{Duration, Instant};
17
/// Configuration for self-tuning optimization.
///
/// Controls which adaptations are enabled (learning rate, optimizer
/// selection, batch size) and how aggressively they are applied.
#[derive(Debug, Clone)]
pub struct SelfTuningConfig {
    /// Window size (in steps) for performance evaluation
    pub evaluation_window: usize,

    /// Minimum improvement threshold for parameter updates; changes in the
    /// target metric smaller than this are treated as stagnation
    pub improvement_threshold: f64,

    /// Maximum number of optimizer switches per epoch
    pub max_switches_per_epoch: usize,

    /// Enable automatic learning rate adjustment
    pub auto_lr_adjustment: bool,

    /// Enable automatic optimizer selection
    pub auto_optimizer_selection: bool,

    /// Enable automatic batch size tuning
    pub auto_batch_size_tuning: bool,

    /// Warmup period (in steps) before any adaptation is attempted
    pub warmup_steps: usize,

    /// Exploration probability for optimizer selection (epsilon-greedy)
    pub exploration_rate: f64,

    /// Multiplicative decay rate for exploration
    pub exploration_decay: f64,

    /// Performance metric to optimize
    pub target_metric: TargetMetric,
}
51
52impl Default for SelfTuningConfig {
53    fn default() -> Self {
54        Self {
55            evaluation_window: 100,
56            improvement_threshold: 0.01,
57            max_switches_per_epoch: 3,
58            auto_lr_adjustment: true,
59            auto_optimizer_selection: true,
60            auto_batch_size_tuning: false,
61            warmup_steps: 1000,
62            exploration_rate: 0.1,
63            exploration_decay: 0.99,
64            target_metric: TargetMetric::Loss,
65        }
66    }
67}
68
/// Target optimization metric.
///
/// Determines both which value is extracted from `PerformanceStats` and the
/// direction of "better" (lower for `Loss`/`ConvergenceTime`, higher for
/// `Accuracy`/`Throughput`/`Custom`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TargetMetric {
    /// Minimize loss
    Loss,
    /// Maximize accuracy
    Accuracy,
    /// Minimize convergence time
    ConvergenceTime,
    /// Maximize training throughput
    Throughput,
    /// Custom metric (user-defined; higher is assumed better)
    Custom,
}
83
/// Performance statistics for tracking optimization progress.
///
/// One instance is supplied per training step and buffered in a sliding
/// window to drive the self-tuning decisions.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// Current loss value
    pub loss: f64,

    /// Current accuracy (if available)
    pub accuracy: Option<f64>,

    /// Gradient norm for the step
    pub gradient_norm: f64,

    /// Training throughput (samples/second)
    pub throughput: f64,

    /// Memory usage (MB)
    pub memory_usage: f64,

    /// Wall clock time for this step
    pub step_time: Duration,

    /// Learning rate used for this step
    pub learning_rate: f64,

    /// Name of the optimizer used for this step
    pub optimizer_type: String,

    /// Custom metrics; with `TargetMetric::Custom` the first entry (in map
    /// iteration order) is used as the target
    pub custom_metrics: HashMap<String, f64>,
}
114
/// Adaptive optimizer that automatically tunes hyperparameters and can
/// switch between candidate optimizers during training.
pub struct SelfTuningOptimizer<A: Float, D: Dimension> {
    /// Configuration
    config: SelfTuningConfig,

    /// Current active optimizer
    current_optimizer: Box<dyn OptimizerTrait<A, D>>,

    /// Available optimizer candidates
    optimizer_candidates: Vec<OptimizerCandidate<A, D>>,

    /// Sliding window of per-step performance, bounded by
    /// `config.evaluation_window`
    performance_history: VecDeque<PerformanceStats>,

    /// Hyperparameter search state
    search_state: HyperparameterSearchState,

    /// Learning rate scheduler
    // NOTE(review): not consulted anywhere in this module yet — confirm
    // whether it is driven by external code before removing.
    lr_scheduler: Option<Box<dyn LearningRateScheduler<A>>>,

    /// Optimizer selection strategy
    selection_strategy: OptimizerSelectionStrategy,

    /// Current step count
    step_count: usize,

    /// Number of optimizer switches in the current epoch
    switches_this_epoch: usize,

    /// Best value of the target metric seen so far
    best_performance: Option<f64>,

    /// Time of last adaptation
    last_adaptation_time: Instant,

    /// Exploration state for multi-armed bandit selection
    bandit_state: BanditState,
}
153
/// Optimizer candidate with its configuration.
///
/// Each candidate is an "arm" for the selection strategies; the factory
/// creates a fresh optimizer instance whenever the candidate is activated.
struct OptimizerCandidate<A: Float, D: Dimension> {
    /// Name/identifier
    name: String,

    /// Factory function to create the optimizer
    factory: Box<dyn Fn() -> Box<dyn OptimizerTrait<A, D>>>,

    /// Performance history for this optimizer
    performance_history: Vec<f64>,

    /// Number of times this candidate has been activated
    usage_count: usize,

    /// Average performance
    average_performance: f64,

    /// Confidence interval (lower, upper)
    confidence_interval: (f64, f64),
}
174
/// Hyperparameter search state.
///
/// Tracks the currently-active hyperparameters, their legal search bounds,
/// and the search algorithm in use.
#[derive(Debug)]
struct HyperparameterSearchState {
    /// Current learning rate
    learning_rate: f64,

    /// Learning rate search bounds (min, max)
    lr_bounds: (f64, f64),

    /// Current batch size
    batch_size: usize,

    /// Batch size search bounds (min, max)
    batch_size_bounds: (usize, usize),

    /// Number of search iterations performed
    search_iterations: usize,

    /// Best hyperparameters found so far, keyed by name
    best_hyperparameters: HashMap<String, f64>,

    /// Search algorithm state
    search_algorithm: SearchAlgorithm,
}
199
/// Hyperparameter search algorithms.
///
/// Each variant carries the mutable state its algorithm needs between
/// iterations.
#[derive(Debug)]
enum SearchAlgorithm {
    /// Random search
    Random {
        /// Random number generator seed
        seed: u64,
    },

    /// Bayesian optimization
    Bayesian {
        /// Gaussian process state
        gp_state: GaussianProcessState,
    },

    /// Grid search
    Grid {
        /// Current grid position (one index per dimension)
        position: Vec<usize>,
        /// Grid dimensions
        dimensions: Vec<usize>,
    },

    /// Successive halving (Hyperband)
    SuccessiveHalving {
        /// Current bracket
        bracket: usize,
        /// Configurations in current round
        configurations: Vec<HashMap<String, f64>>,
    },
}
231
/// Gaussian process state for Bayesian optimization.
#[derive(Debug)]
struct GaussianProcessState {
    /// Observed points (one hyperparameter vector per observation)
    observed_points: Vec<Vec<f64>>,

    /// Observed objective values, parallel to `observed_points`
    observed_values: Vec<f64>,

    /// Kernel hyperparameters
    kernel_params: Vec<f64>,

    /// Acquisition function type
    acquisition_function: AcquisitionFunction,
}
247
/// Acquisition functions for Bayesian optimization.
#[derive(Debug, Clone, Copy)]
enum AcquisitionFunction {
    /// Expected improvement over the current best
    ExpectedImprovement,
    /// Probability of improving over the current best
    ProbabilityOfImprovement,
    /// Mean plus scaled uncertainty (UCB)
    UpperConfidenceBound,
}
255
/// Optimizer selection strategies.
///
/// Determines how the next candidate optimizer is chosen when an adaptation
/// is triggered.
#[derive(Debug, Clone)]
enum OptimizerSelectionStrategy {
    /// Multi-armed bandit approach
    MultiArmedBandit {
        /// Bandit algorithm type
        algorithm: BanditAlgorithm,
    },

    /// Performance-based selection (greedy on average performance)
    PerformanceBased {
        /// Minimum performance difference for switching
        min_difference: f64,
    },

    /// Round-robin testing of all candidates in turn
    RoundRobin {
        /// Current optimizer index (cursor into the candidate list)
        current_index: usize,
    },

    /// Meta-learning based selection
    MetaLearning {
        /// Problem characteristics
        problem_features: Vec<f64>,
        /// Learned optimizer mappings
        optimizer_mappings: HashMap<String, f64>,
    },
}
285
/// Multi-armed bandit algorithms for optimizer selection.
#[derive(Debug, Clone, Copy)]
enum BanditAlgorithm {
    /// Explore with probability epsilon, otherwise pick the best arm
    EpsilonGreedy,
    /// Upper Confidence Bound (UCB1)
    UCB1,
    /// Sample from per-arm reward distributions
    ThompsonSampling,
    /// Contextual UCB (currently falls back to UCB1 in this module)
    LinUCB,
}
294
/// Multi-armed bandit state.
///
/// All vectors are parallel to the candidate list in
/// `SelfTuningOptimizer::optimizer_candidates`.
#[derive(Debug)]
struct BanditState {
    /// Reward estimates for each optimizer
    reward_estimates: Vec<f64>,

    /// Confidence bounds (used as sampling widths for Thompson sampling)
    confidence_bounds: Vec<f64>,

    /// Selection counts per arm
    selection_counts: Vec<usize>,

    /// Total selections across all arms
    total_selections: usize,

    /// Exploration parameter (UCB1 confidence multiplier)
    exploration_param: f64,
}
313
/// Trait for optimizer implementations that can be used with self-tuning.
///
/// Implementations wrap a concrete optimizer so that heterogeneous
/// optimizers can be swapped behind a common interface at runtime.
pub trait OptimizerTrait<A: Float + ScalarOperand + Debug, D: Dimension>: Send + Sync {
    /// Get optimizer name
    fn name(&self) -> &str;

    /// Perform one optimization step, updating `params` in place.
    /// Implementations in this module return an error when `params` and
    /// `grads` have different lengths.
    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()>;

    /// Get current learning rate
    fn learning_rate(&self) -> A;

    /// Set learning rate
    fn set_learning_rate(&mut self, lr: A);

    /// Get optimizer state for serialization (may be empty if the
    /// implementation does not support state transfer)
    fn get_state(&self) -> HashMap<String, Vec<u8>>;

    /// Set optimizer state from serialization; implementations may ignore
    /// unrecognized state
    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()>;

    /// Clone the optimizer behind the trait object
    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>>;
}
337
338impl<
339        A: Float + ScalarOperand + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive,
340        D: Dimension + 'static,
341    > SelfTuningOptimizer<A, D>
342{
343    /// Create new self-tuning optimizer
344    pub fn new(config: SelfTuningConfig) -> Result<Self> {
345        let mut optimizer_candidates = Vec::new();
346
347        // Add default optimizer candidates
348        optimizer_candidates.push(OptimizerCandidate {
349            name: "Adam".to_string(),
350            factory: Box::new(|| Box::new(AdamOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.0))),
351            performance_history: Vec::new(),
352            usage_count: 0,
353            average_performance: 0.0,
354            confidence_interval: (0.0, 0.0),
355        });
356
357        optimizer_candidates.push(OptimizerCandidate {
358            name: "SGD".to_string(),
359            factory: Box::new(|| Box::new(SGDOptimizerWrapper::new(0.01, 0.9, 0.0, false))),
360            performance_history: Vec::new(),
361            usage_count: 0,
362            average_performance: 0.0,
363            confidence_interval: (0.0, 0.0),
364        });
365
366        optimizer_candidates.push(OptimizerCandidate {
367            name: "AdamW".to_string(),
368            factory: Box::new(|| {
369                Box::new(AdamWOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.01))
370            }),
371            performance_history: Vec::new(),
372            usage_count: 0,
373            average_performance: 0.0,
374            confidence_interval: (0.0, 0.0),
375        });
376
377        // Start with Adam as default
378        let current_optimizer = (optimizer_candidates[0].factory)();
379
380        let search_state = HyperparameterSearchState {
381            learning_rate: 0.001,
382            lr_bounds: (1e-6, 1.0),
383            batch_size: 32,
384            batch_size_bounds: (8, 512),
385            search_iterations: 0,
386            best_hyperparameters: HashMap::new(),
387            search_algorithm: SearchAlgorithm::Random { seed: 42 },
388        };
389
390        let selection_strategy = OptimizerSelectionStrategy::MultiArmedBandit {
391            algorithm: BanditAlgorithm::UCB1,
392        };
393
394        let bandit_state = BanditState {
395            reward_estimates: vec![0.0; optimizer_candidates.len()],
396            confidence_bounds: vec![1.0; optimizer_candidates.len()],
397            selection_counts: vec![0; optimizer_candidates.len()],
398            total_selections: 0,
399            exploration_param: 2.0,
400        };
401
402        Ok(Self {
403            config,
404            current_optimizer,
405            optimizer_candidates,
406            performance_history: VecDeque::new(),
407            search_state,
408            lr_scheduler: None,
409            selection_strategy,
410            step_count: 0,
411            switches_this_epoch: 0,
412            best_performance: None,
413            last_adaptation_time: Instant::now(),
414            bandit_state,
415        })
416    }
417
418    /// Add a custom optimizer candidate
419    pub fn add_optimizer_candidate<F>(&mut self, name: String, factory: F)
420    where
421        F: Fn() -> Box<dyn OptimizerTrait<A, D>> + 'static,
422    {
423        self.optimizer_candidates.push(OptimizerCandidate {
424            name,
425            factory: Box::new(factory),
426            performance_history: Vec::new(),
427            usage_count: 0,
428            average_performance: 0.0,
429            confidence_interval: (0.0, 0.0),
430        });
431
432        // Update bandit state
433        self.bandit_state.reward_estimates.push(0.0);
434        self.bandit_state.confidence_bounds.push(1.0);
435        self.bandit_state.selection_counts.push(0);
436    }
437
438    /// Perform optimization step with automatic tuning
439    pub fn step(
440        &mut self,
441        params: &mut [Array<A, D>],
442        grads: &[Array<A, D>],
443        stats: PerformanceStats,
444    ) -> Result<()> {
445        self.step_count += 1;
446
447        // Record performance
448        self.performance_history.push_back(stats.clone());
449        if self.performance_history.len() > self.config.evaluation_window {
450            self.performance_history.pop_front();
451        }
452
453        // Perform optimization step
454        self.current_optimizer.step(params, grads)?;
455
456        // Self-tuning adaptations
457        if self.step_count > self.config.warmup_steps {
458            self.maybe_adapt_optimizer(&stats)?;
459            self.maybe_adapt_learning_rate(&stats)?;
460            self.maybe_adapt_hyperparameters(&stats)?;
461        }
462
463        // Update best performance
464        let current_performance = self.extract_performance_metric(&stats);
465        if let Some(performance) = current_performance {
466            if self.best_performance.is_none()
467                || self.is_better_performance(
468                    performance,
469                    self.best_performance.expect("unwrap failed"),
470                )
471            {
472                self.best_performance = Some(performance);
473            }
474        }
475
476        Ok(())
477    }
478
479    /// Check if we should adapt the optimizer
480    fn maybe_adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
481        if !self.config.auto_optimizer_selection {
482            return Ok(());
483        }
484
485        if self.switches_this_epoch >= self.config.max_switches_per_epoch {
486            return Ok(());
487        }
488
489        let should_adapt = self.should_adapt_optimizer(stats);
490
491        if should_adapt {
492            self.adapt_optimizer(stats)?;
493            self.switches_this_epoch += 1;
494        }
495
496        Ok(())
497    }
498
499    /// Determine if optimizer should be adapted
500    fn should_adapt_optimizer(&self, stats: &PerformanceStats) -> bool {
501        if self.performance_history.len() < self.config.evaluation_window / 2 {
502            return false;
503        }
504
505        // Check for performance degradation or stagnation
506        let recent_performance: Vec<f64> = self
507            .performance_history
508            .iter()
509            .rev()
510            .take(self.config.evaluation_window / 4)
511            .filter_map(|s| self.extract_performance_metric(s))
512            .collect();
513
514        let older_performance: Vec<f64> = self
515            .performance_history
516            .iter()
517            .rev()
518            .skip(self.config.evaluation_window / 4)
519            .take(self.config.evaluation_window / 4)
520            .filter_map(|s| self.extract_performance_metric(s))
521            .collect();
522
523        if recent_performance.is_empty() || older_performance.is_empty() {
524            return false;
525        }
526
527        let recent_avg = recent_performance.iter().sum::<f64>() / recent_performance.len() as f64;
528        let older_avg = older_performance.iter().sum::<f64>() / older_performance.len() as f64;
529
530        // Check for stagnation or degradation
531        match self.config.target_metric {
532            TargetMetric::Loss => {
533                (recent_avg - older_avg).abs() < self.config.improvement_threshold
534                    || recent_avg > older_avg
535            }
536            TargetMetric::Accuracy | TargetMetric::Throughput => {
537                (recent_avg - older_avg).abs() < self.config.improvement_threshold
538                    || recent_avg < older_avg
539            }
540            _ => false,
541        }
542    }
543
544    /// Adapt the optimizer based on performance
545    fn adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
546        let new_optimizer_idx = match &self.selection_strategy {
547            OptimizerSelectionStrategy::MultiArmedBandit { algorithm } => {
548                self.select_optimizer_bandit(*algorithm)
549            }
550            OptimizerSelectionStrategy::PerformanceBased { .. } => {
551                self.select_optimizer_performance_based()
552            }
553            OptimizerSelectionStrategy::RoundRobin { current_index } => {
554                (*current_index + 1) % self.optimizer_candidates.len()
555            }
556            OptimizerSelectionStrategy::MetaLearning { .. } => {
557                self.select_optimizer_meta_learning(stats)
558            }
559        };
560
561        // Switch to new optimizer
562        if new_optimizer_idx < self.optimizer_candidates.len() {
563            let current_lr = self.current_optimizer.learning_rate();
564            let current_state = self.current_optimizer.get_state();
565
566            self.current_optimizer = (self.optimizer_candidates[new_optimizer_idx].factory)();
567            self.current_optimizer.set_learning_rate(current_lr);
568
569            // Try to transfer compatible state
570            if self.current_optimizer.set_state(current_state).is_err() {
571                // State transfer failed, continue with fresh state
572            }
573
574            // Update usage statistics
575            self.optimizer_candidates[new_optimizer_idx].usage_count += 1;
576        }
577
578        Ok(())
579    }
580
581    /// Select optimizer using multi-armed bandit
582    fn select_optimizer_bandit(&mut self, algorithm: BanditAlgorithm) -> usize {
583        match algorithm {
584            BanditAlgorithm::UCB1 => self.select_ucb1(),
585            BanditAlgorithm::EpsilonGreedy => self.select_epsilon_greedy(),
586            BanditAlgorithm::ThompsonSampling => self.select_thompson_sampling(),
587            BanditAlgorithm::LinUCB => self.select_linucb(),
588        }
589    }
590
591    /// UCB1 optimizer selection
592    fn select_ucb1(&self) -> usize {
593        if self.bandit_state.total_selections == 0 {
594            return 0;
595        }
596
597        let mut best_score = f64::NEG_INFINITY;
598        let mut best_idx = 0;
599
600        for (i, candidate) in self.optimizer_candidates.iter().enumerate() {
601            let ucb_score = if self.bandit_state.selection_counts[i] == 0 {
602                f64::INFINITY
603            } else {
604                let mean_reward = self.bandit_state.reward_estimates[i];
605                let confidence = (self.bandit_state.exploration_param
606                    * (self.bandit_state.total_selections as f64).ln()
607                    / self.bandit_state.selection_counts[i] as f64)
608                    .sqrt();
609                mean_reward + confidence
610            };
611
612            if ucb_score > best_score {
613                best_score = ucb_score;
614                best_idx = i;
615            }
616        }
617
618        best_idx
619    }
620
621    /// Epsilon-greedy optimizer selection
622    fn select_epsilon_greedy(&self) -> usize {
623        let mut rng = thread_rng();
624
625        if A::from(rng.random::<f64>()).expect("unwrap failed")
626            < A::from(self.config.exploration_rate).expect("unwrap failed")
627        {
628            // Explore: random selection
629            rng.gen_range(0..self.optimizer_candidates.len())
630        } else {
631            // Exploit: best performing optimizer
632            self.bandit_state
633                .reward_estimates
634                .iter()
635                .enumerate()
636                .max_by(|a, b| a.1.partial_cmp(b.1).expect("unwrap failed"))
637                .map(|(idx, _)| idx)
638                .unwrap_or(0)
639        }
640    }
641
642    /// Thompson sampling optimizer selection
643    fn select_thompson_sampling(&self) -> usize {
644        // Simplified Thompson sampling - in practice would use Beta distributions
645        let mut rng = thread_rng();
646
647        let mut best_sample = f64::NEG_INFINITY;
648        let mut best_idx = 0;
649
650        for (i, _) in self.optimizer_candidates.iter().enumerate() {
651            let mean = self.bandit_state.reward_estimates[i];
652            let std = self.bandit_state.confidence_bounds[i];
653            let sample = rng.gen_range(mean - std..mean + std);
654
655            if sample > best_sample {
656                best_sample = sample;
657                best_idx = i;
658            }
659        }
660
661        best_idx
662    }
663
664    /// LinUCB optimizer selection (contextual bandit)
665    fn select_linucb(&self) -> usize {
666        // Simplified LinUCB - would use contextual features in practice
667        self.select_ucb1()
668    }
669
670    /// Performance-based optimizer selection
671    fn select_optimizer_performance_based(&self) -> usize {
672        self.optimizer_candidates
673            .iter()
674            .enumerate()
675            .max_by(|a, b| {
676                a.1.average_performance
677                    .partial_cmp(&b.1.average_performance)
678                    .expect("unwrap failed")
679            })
680            .map(|(idx, _)| idx)
681            .unwrap_or(0)
682    }
683
684    /// Meta-learning based optimizer selection
685    fn select_optimizer_meta_learning(&self, stats: &PerformanceStats) -> usize {
686        // Simplified meta-learning - would use problem features in practice
687        0
688    }
689
690    /// Adapt learning rate based on performance
691    fn maybe_adapt_learning_rate(&mut self, stats: &PerformanceStats) -> Result<()> {
692        if !self.config.auto_lr_adjustment {
693            return Ok(());
694        }
695
696        // Simple adaptive learning rate based on gradient norm
697        let current_lr = self
698            .current_optimizer
699            .learning_rate()
700            .to_f64()
701            .expect("unwrap failed");
702        let gradient_norm = stats.gradient_norm;
703
704        let new_lr = if gradient_norm > 10.0 {
705            // Large gradients - reduce learning rate
706            current_lr * 0.9
707        } else if gradient_norm < 0.1 {
708            // Small gradients - increase learning rate
709            current_lr * 1.1
710        } else {
711            current_lr
712        };
713
714        let clamped_lr = new_lr
715            .max(self.search_state.lr_bounds.0)
716            .min(self.search_state.lr_bounds.1);
717
718        if (clamped_lr - current_lr).abs() > current_lr * 0.01 {
719            self.current_optimizer
720                .set_learning_rate(A::from(clamped_lr).expect("unwrap failed"));
721            self.search_state.learning_rate = clamped_lr;
722        }
723
724        Ok(())
725    }
726
727    /// Adapt other hyperparameters
728    fn maybe_adapt_hyperparameters(&mut self, stats: &PerformanceStats) -> Result<()> {
729        // Placeholder for hyperparameter adaptation
730        // Would implement Bayesian optimization, random search, etc.
731        Ok(())
732    }
733
734    /// Extract performance metric from stats
735    fn extract_performance_metric(&self, stats: &PerformanceStats) -> Option<f64> {
736        match self.config.target_metric {
737            TargetMetric::Loss => Some(stats.loss),
738            TargetMetric::Accuracy => stats.accuracy,
739            TargetMetric::Throughput => Some(stats.throughput),
740            TargetMetric::ConvergenceTime => Some(stats.step_time.as_secs_f64()),
741            TargetMetric::Custom => stats.custom_metrics.values().next().copied(),
742        }
743    }
744
745    /// Check if performance is better
746    fn is_better_performance(&self, new_perf: f64, oldperf: f64) -> bool {
747        match self.config.target_metric {
748            TargetMetric::Loss | TargetMetric::ConvergenceTime => new_perf < oldperf,
749            TargetMetric::Accuracy | TargetMetric::Throughput => new_perf > oldperf,
750            TargetMetric::Custom => new_perf > oldperf, // Assume higher is better for custom
751        }
752    }
753
754    /// Reset epoch counters
755    pub fn reset_epoch(&mut self) {
756        self.switches_this_epoch = 0;
757    }
758
759    /// Get current optimizer information
760    pub fn get_optimizer_info(&self) -> OptimizerInfo {
761        OptimizerInfo {
762            name: self.current_optimizer.name().to_string(),
763            learning_rate: self
764                .current_optimizer
765                .learning_rate()
766                .to_f64()
767                .expect("unwrap failed"),
768            step_count: self.step_count,
769            switches_this_epoch: self.switches_this_epoch,
770            performance_window_size: self.performance_history.len(),
771            best_performance: self.best_performance,
772        }
773    }
774
775    /// Get optimization statistics
776    pub fn get_statistics(&self) -> SelfTuningStatistics {
777        let optimizer_usage: HashMap<String, usize> = self
778            .optimizer_candidates
779            .iter()
780            .map(|c| (c.name.clone(), c.usage_count))
781            .collect();
782
783        SelfTuningStatistics {
784            total_steps: self.step_count,
785            total_optimizer_switches: self
786                .optimizer_candidates
787                .iter()
788                .map(|c| c.usage_count)
789                .sum(),
790            optimizer_usage,
791            current_learning_rate: self.search_state.learning_rate,
792            average_step_time: self
793                .performance_history
794                .iter()
795                .map(|s| s.step_time.as_secs_f64())
796                .sum::<f64>()
797                / self.performance_history.len().max(1) as f64,
798            exploration_rate: self.config.exploration_rate,
799        }
800    }
801}
802
/// Information about the current optimizer state, as returned by
/// `SelfTuningOptimizer::get_optimizer_info`.
#[derive(Debug, Clone)]
pub struct OptimizerInfo {
    /// Name of the active optimizer
    pub name: String,
    /// Current learning rate
    pub learning_rate: f64,
    /// Total steps taken
    pub step_count: usize,
    /// Optimizer switches made in the current epoch
    pub switches_this_epoch: usize,
    /// Number of stats records currently in the sliding window
    pub performance_window_size: usize,
    /// Best target-metric value observed so far, if any
    pub best_performance: Option<f64>,
}
813
/// Statistics about self-tuning optimization, as returned by
/// `SelfTuningOptimizer::get_statistics`.
#[derive(Debug, Clone)]
pub struct SelfTuningStatistics {
    /// Total steps taken
    pub total_steps: usize,
    /// Sum of activation counts across all candidates
    pub total_optimizer_switches: usize,
    /// Activation count per candidate name
    pub optimizer_usage: HashMap<String, usize>,
    /// Learning rate currently tracked by the search state
    pub current_learning_rate: f64,
    /// Mean step time (seconds) over the sliding window
    pub average_step_time: f64,
    /// Configured exploration rate
    pub exploration_rate: f64,
}
824
// Wrapper implementations adapting the crate's concrete optimizers to the
// dyn-compatible `OptimizerTrait` interface.

/// Adapts `crate::optimizers::Adam` to `OptimizerTrait`.
struct AdamOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    // The wrapped Adam optimizer.
    inner: crate::optimizers::Adam<A>,
    // Carries the dimension type parameter without storing a value of it.
    _phantom: std::marker::PhantomData<D>,
}
830
831impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
832    AdamOptimizerWrapper<A, D>
833{
834    fn new(_lr: f64, beta1: f64, beta2: f64, eps: f64, weightdecay: f64) -> Self {
835        Self {
836            inner: crate::optimizers::Adam::new_with_config(
837                A::from(_lr).expect("unwrap failed"),
838                A::from(beta1).expect("unwrap failed"),
839                A::from(beta2).expect("unwrap failed"),
840                A::from(eps).expect("unwrap failed"),
841                A::from(weightdecay).expect("unwrap failed"),
842            ),
843            _phantom: std::marker::PhantomData,
844        }
845    }
846}
847
848impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
849    OptimizerTrait<A, D> for AdamOptimizerWrapper<A, D>
850{
851    fn name(&self) -> &str {
852        "Adam"
853    }
854
855    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
856        if params.len() != grads.len() {
857            return Err(crate::error::OptimError::InvalidParameter(
858                "Mismatched number of parameters and gradients".to_string(),
859            ));
860        }
861
862        for (param, grad) in params.iter_mut().zip(grads.iter()) {
863            let updated = self.inner.step(param, grad)?;
864            *param = updated;
865        }
866        Ok(())
867    }
868
869    fn learning_rate(&self) -> A {
870        self.inner.learning_rate()
871    }
872
873    fn set_learning_rate(&mut self, lr: A) {
874        <crate::optimizers::Adam<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
875            &mut self.inner,
876            lr,
877        );
878    }
879
880    fn get_state(&self) -> HashMap<String, Vec<u8>> {
881        HashMap::new()
882    }
883
884    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
885        Ok(())
886    }
887
888    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
889        Box::new(AdamOptimizerWrapper {
890            inner: self.inner.clone(),
891            _phantom: std::marker::PhantomData,
892        })
893    }
894}
895
/// Adapts `crate::optimizers::SGD` to `OptimizerTrait`.
struct SGDOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    // The wrapped SGD optimizer.
    inner: crate::optimizers::SGD<A>,
    // Carries the dimension type parameter without storing a value of it.
    _phantom: std::marker::PhantomData<D>,
}
900
901impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
902    SGDOptimizerWrapper<A, D>
903{
904    fn new(_lr: f64, momentum: f64, weightdecay: f64, nesterov: bool) -> Self {
905        Self {
906            inner: crate::optimizers::SGD::new_with_config(
907                A::from(_lr).expect("unwrap failed"),
908                A::from(momentum).expect("unwrap failed"),
909                A::from(weightdecay).expect("unwrap failed"),
910            ),
911            _phantom: std::marker::PhantomData,
912        }
913    }
914}
915
916impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
917    OptimizerTrait<A, D> for SGDOptimizerWrapper<A, D>
918{
919    fn name(&self) -> &str {
920        "SGD"
921    }
922
923    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
924        if params.len() != grads.len() {
925            return Err(crate::error::OptimError::InvalidParameter(
926                "Mismatched number of parameters and gradients".to_string(),
927            ));
928        }
929
930        for (param, grad) in params.iter_mut().zip(grads.iter()) {
931            let updated = self.inner.step(param, grad)?;
932            *param = updated;
933        }
934        Ok(())
935    }
936
937    fn learning_rate(&self) -> A {
938        self.inner.learning_rate()
939    }
940
941    fn set_learning_rate(&mut self, lr: A) {
942        <crate::optimizers::SGD<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
943            &mut self.inner,
944            lr,
945        );
946    }
947
948    fn get_state(&self) -> HashMap<String, Vec<u8>> {
949        HashMap::new()
950    }
951
952    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
953        Ok(())
954    }
955
956    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
957        Box::new(SGDOptimizerWrapper {
958            inner: self.inner.clone(),
959            _phantom: std::marker::PhantomData,
960        })
961    }
962}
963
/// Adapter that exposes the concrete `crate::optimizers::AdamW` optimizer
/// through the dynamic `OptimizerTrait<A, D>` interface used by the
/// self-tuning optimizer's candidate pool.
struct AdamWOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    // Wrapped concrete AdamW optimizer that performs the actual parameter updates.
    inner: crate::optimizers::AdamW<A>,
    // Carries the dimension type parameter `D`, which the wrapper does not store directly.
    _phantom: std::marker::PhantomData<D>,
}
968
969impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
970    AdamWOptimizerWrapper<A, D>
971{
972    fn new(_lr: f64, beta1: f64, beta2: f64, eps: f64, weightdecay: f64) -> Self {
973        Self {
974            inner: crate::optimizers::AdamW::new_with_config(
975                A::from(_lr).expect("unwrap failed"),
976                A::from(beta1).expect("unwrap failed"),
977                A::from(beta2).expect("unwrap failed"),
978                A::from(eps).expect("unwrap failed"),
979                A::from(weightdecay).expect("unwrap failed"),
980            ),
981            _phantom: std::marker::PhantomData,
982        }
983    }
984}
985
986impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
987    OptimizerTrait<A, D> for AdamWOptimizerWrapper<A, D>
988{
989    fn name(&self) -> &str {
990        "AdamW"
991    }
992
993    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
994        if params.len() != grads.len() {
995            return Err(crate::error::OptimError::InvalidParameter(
996                "Mismatched number of parameters and gradients".to_string(),
997            ));
998        }
999
1000        for (param, grad) in params.iter_mut().zip(grads.iter()) {
1001            let updated = self.inner.step(param, grad)?;
1002            *param = updated;
1003        }
1004        Ok(())
1005    }
1006
1007    fn learning_rate(&self) -> A {
1008        self.inner.learning_rate()
1009    }
1010
1011    fn set_learning_rate(&mut self, lr: A) {
1012        <crate::optimizers::AdamW<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
1013            &mut self.inner,
1014            lr,
1015        );
1016    }
1017
1018    fn get_state(&self) -> HashMap<String, Vec<u8>> {
1019        HashMap::new()
1020    }
1021
1022    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
1023        Ok(())
1024    }
1025
1026    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
1027        Box::new(AdamWOptimizerWrapper {
1028            inner: self.inner.clone(),
1029            _phantom: std::marker::PhantomData,
1030        })
1031    }
1032}
1033
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::time::Duration;

    #[test]
    fn test_self_tuning_config_default() {
        let cfg = SelfTuningConfig::default();
        assert!(cfg.auto_lr_adjustment);
        assert!(cfg.auto_optimizer_selection);
        assert_eq!(cfg.evaluation_window, 100);
    }

    #[test]
    fn test_self_tuning_optimizer_creation() {
        let created: Result<SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1>> =
            SelfTuningOptimizer::new(SelfTuningConfig::default());
        assert!(created.is_ok());
    }

    #[test]
    fn test_performance_stats() {
        // Populate every field and check that values round-trip unchanged.
        let perf = PerformanceStats {
            optimizer_type: "Adam".to_string(),
            loss: 0.5,
            accuracy: Some(0.9),
            gradient_norm: 1.2,
            throughput: 100.0,
            memory_usage: 1024.0,
            step_time: Duration::from_millis(50),
            learning_rate: 0.001,
            custom_metrics: HashMap::new(),
        };

        assert_eq!(perf.loss, 0.5);
        assert_eq!(perf.accuracy, Some(0.9));
    }

    #[test]
    fn test_optimizer_step() {
        let mut opt: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(SelfTuningConfig::default()).expect("unwrap failed");

        let mut params = vec![Array1::zeros(10)];
        let grads = vec![Array1::ones(10)];

        let perf = PerformanceStats {
            optimizer_type: "Adam".to_string(),
            loss: 1.0,
            accuracy: None,
            gradient_norm: 1.0,
            throughput: 50.0,
            memory_usage: 512.0,
            step_time: Duration::from_millis(10),
            learning_rate: 0.001,
            custom_metrics: HashMap::new(),
        };

        assert!(opt.step(&mut params, &grads, perf).is_ok());

        // A single step should be reflected in the optimizer info.
        let info = opt.get_optimizer_info();
        assert_eq!(info.name, "Adam");
        assert_eq!(info.step_count, 1);
    }

    #[test]
    fn test_bandit_selection() {
        let opt: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(SelfTuningConfig::default()).expect("unwrap failed");

        // UCB1 must always pick an index within the candidate pool.
        let chosen = opt.select_ucb1();
        assert!(chosen < opt.optimizer_candidates.len());
    }

    #[test]
    fn test_performance_metric_extraction() {
        let cfg = SelfTuningConfig {
            target_metric: TargetMetric::Loss,
            ..Default::default()
        };
        let opt: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(cfg).expect("unwrap failed");

        let perf = PerformanceStats {
            optimizer_type: "Adam".to_string(),
            loss: 0.8,
            accuracy: Some(0.85),
            gradient_norm: 1.1,
            throughput: 75.0,
            memory_usage: 800.0,
            step_time: Duration::from_millis(20),
            learning_rate: 0.001,
            custom_metrics: HashMap::new(),
        };

        // With `TargetMetric::Loss`, extraction should yield the loss field.
        assert_eq!(opt.extract_performance_metric(&perf), Some(0.8));
    }

    #[test]
    fn test_statistics() {
        let opt: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(SelfTuningConfig::default()).expect("unwrap failed");

        let summary = opt.get_statistics();
        assert_eq!(summary.total_steps, 0);
        assert!(summary.optimizer_usage.contains_key("Adam"));
    }
}