1use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::time::{Duration, Instant};
11
12#[allow(unused_imports)]
13use crate::error::Result;
14
15#[derive(Debug, Clone)]
17pub enum PerformanceMetric<A: Float + Send + Sync> {
18 Loss(A),
19 Accuracy(A),
20 F1Score(A),
21 AUC(A),
22 Custom { name: String, value: A },
23}
24
25#[derive(Debug, Clone)]
27pub struct EnhancedAdaptiveLRController<A: Float + Send + Sync> {
28 current_lr: A,
30
31 base_lr: A,
33
34 min_lr: A,
36 max_lr: A,
37
38 adaptation_strategy: MultiSignalAdaptationStrategy<A>,
40
41 gradient_adapter: GradientBasedAdapter<A>,
43
44 performance_adapter: PerformanceBasedAdapter<A>,
46
47 drift_adapter: DriftAwareAdapter<A>,
49
50 resource_adapter: ResourceAwareAdapter<A>,
52
53 meta_optimizer: MetaOptimizer<A>,
55
56 adaptation_history: VecDeque<AdaptationEvent<A>>,
58
59 config: AdaptiveLRConfig<A>,
61}
62
63#[derive(Debug, Clone)]
65pub struct AdaptiveLRConfig<A: Float + Send + Sync> {
66 pub base_lr: A,
68
69 pub min_lr: A,
71
72 pub max_lr: A,
74
75 pub enable_gradient_adaptation: bool,
77
78 pub enable_performance_adaptation: bool,
80
81 pub enable_drift_adaptation: bool,
83
84 pub enable_resource_adaptation: bool,
86
87 pub enable_meta_learning: bool,
89
90 pub history_window_size: usize,
92
93 pub adaptation_frequency: usize,
95
96 pub adaptation_sensitivity: A,
98
99 pub use_ensemble_voting: bool,
101}
102
103#[derive(Debug, Clone)]
105pub struct MultiSignalAdaptationStrategy<A: Float + Send + Sync> {
106 signal_weights: HashMap<AdaptationSignalType, A>,
108
109 voting_history: VecDeque<SignalVote<A>>,
111
112 conflict_resolution: ConflictResolution,
114
115 signal_reliability: HashMap<AdaptationSignalType, A>,
117
118 last_decision: Option<AdaptationDecision<A>>,
120}
121
/// Identifies which kind of evidence produced an adaptation vote.
///
/// Only `GradientMagnitude`, `LossProgression`, `ConceptDrift` and
/// `ResourceUtilization` are emitted by adapters in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdaptationSignalType {
    /// Emitted by the gradient adapter.
    GradientMagnitude,
    /// Not emitted by any adapter in this file.
    GradientVariance,
    /// Emitted by the performance adapter.
    LossProgression,
    /// Not emitted by any adapter in this file.
    AccuracyTrend,
    /// Emitted by the drift adapter.
    ConceptDrift,
    /// Emitted by the resource adapter.
    ResourceUtilization,
    /// Not emitted by any adapter in this file.
    ModelComplexity,
    /// Not emitted by any adapter in this file.
    DataQuality,
}
134
135#[derive(Debug, Clone)]
137pub struct SignalVote<A: Float + Send + Sync> {
138 signal_type: AdaptationSignalType,
139 recommended_lr_change: A, confidence: A,
141 reasoning: String,
142 timestamp: Instant,
143}
144
/// Mode for reconciling disagreeing votes.
///
/// NOTE(review): `MultiSignalAdaptationStrategy::resolve_signals` in this
/// file always performs a confidence-weighted average and does not consult
/// this setting; the other variants describe intended behaviour only.
#[derive(Debug, Clone, Copy)]
pub enum ConflictResolution {
    /// Confidence-weighted average of the recommended changes.
    WeightedAverage,
    /// Presumably: follow the single most confident vote.
    HighestConfidence,
    /// Presumably: majority vote gated by `threshold`.
    MajorityVote { threshold: f64 },
    /// Presumably: prefer the most cautious recommendation.
    Conservative,
    /// Presumably: a learned resolution rule.
    MetaLearned,
}
159
160#[derive(Debug, Clone)]
162pub struct AdaptationDecision<A: Float + Send + Sync> {
163 new_lr: A,
164 lr_multiplier: A,
165 contributing_signals: Vec<AdaptationSignalType>,
166 confidence: A,
167 rationale: String,
168 timestamp: Instant,
169}
170
171#[derive(Debug, Clone)]
173pub struct GradientBasedAdapter<A: Float + Send + Sync> {
174 magnitude_history: VecDeque<A>,
176
177 direction_variance_history: VecDeque<A>,
179
180 norm_statistics: GradientNormStatistics<A>,
182
183 snr_estimator: SignalToNoiseEstimator<A>,
185
186 staleness_detector: GradientStalenessDetector<A>,
188}
189
190#[derive(Debug, Clone)]
192pub struct PerformanceBasedAdapter<A: Float + Send + Sync> {
193 metric_history: HashMap<String, VecDeque<A>>,
195
196 trend_analyzer: PerformanceTrendAnalyzer<A>,
198
199 plateau_detector: PlateauDetector<A>,
201
202 overfitting_detector: OverfittingDetector<A>,
204
205 efficiency_tracker: LearningEfficiencyTracker<A>,
207}
208
209#[derive(Debug, Clone)]
211pub struct DriftAwareAdapter<A: Float + Send + Sync> {
212 drift_detectors: Vec<ConceptDriftDetector<A>>,
214
215 distribution_tracker: DistributionTracker<A>,
217
218 adaptation_speed: AdaptationSpeedController<A>,
220
221 drift_severity: DriftSeverityAssessor<A>,
223}
224
225#[derive(Debug, Clone)]
227pub struct ResourceAwareAdapter<A: Float + Send + Sync> {
228 memory_tracker: MemoryUsageTracker,
230
231 compute_tracker: ComputationTimeTracker,
233
234 energy_tracker: EnergyConsumptionTracker,
236
237 throughput_requirements: ThroughputRequirements<A>,
239
240 budget_manager: ResourceBudgetManager<A>,
242}
243
244#[derive(Debug, Clone)]
246pub struct MetaOptimizer<A: Float + Send + Sync> {
247 lr_predictor: LearningRatePredictorNetwork<A>,
249
250 optimization_history: VecDeque<HyperparameterUpdate<A>>,
252
253 exploration_strategy: ExplorationStrategy<A>,
255
256 transfer_learner: TransferLearner<A>,
258}
259
260#[derive(Debug, Clone)]
262pub struct AdaptationEvent<A: Float + Send + Sync> {
263 timestamp: Instant,
264 old_lr: A,
265 new_lr: A,
266 trigger_signals: Vec<AdaptationSignalType>,
267 adaptation_reason: String,
268 confidence: A,
269 effectiveness_score: Option<A>, }
271
272#[derive(Debug, Clone)]
274pub struct GradientNormStatistics<A: Float + Send + Sync> {
275 mean: A,
276 variance: A,
277 skewness: A,
278 kurtosis: A,
279 percentiles: Vec<A>, autocorrelation: A,
281}
282
283#[derive(Debug, Clone)]
285pub struct SignalToNoiseEstimator<A: Float + Send + Sync> {
286 signal_estimate: A,
287 noise_estimate: A,
288 snr_history: VecDeque<A>,
289 estimation_method: SNREstimationMethod,
290}
291
/// Technique used by [`SignalToNoiseEstimator`]; only the variant name is
/// defined here, no estimator logic exists in this file.
#[derive(Debug, Clone, Copy)]
pub enum SNREstimationMethod {
    MovingAverage,
    ExponentialSmoothing,
    RobustEstimation,
    WaveletDenoising,
}
299
300#[derive(Debug, Clone)]
302pub struct GradientStalenessDetector<A: Float + Send + Sync> {
303 staleness_threshold: Duration,
304 gradient_timestamps: VecDeque<Instant>,
305 staleness_impact_model: StalenessImpactModel<A>,
306}
307
308#[derive(Debug, Clone)]
309pub struct StalenessImpactModel<A: Float + Send + Sync> {
310 staleness_penalty: A,
311 compensation_factor: A,
312 impact_history: VecDeque<A>,
313}
314
315#[derive(Debug, Clone)]
317pub struct PerformanceTrendAnalyzer<A: Float + Send + Sync> {
318 trend_detection_window: usize,
319 trend_types: Vec<TrendType>,
320 trend_strength: A,
321 trend_duration: Duration,
322}
323
/// Qualitative label for a performance trend.
#[derive(Debug, Clone, Copy)]
pub enum TrendType {
    Improving,
    Degrading,
    Oscillating,
    Plateau,
    Volatile,
}
332
333#[derive(Debug, Clone)]
335pub struct PlateauDetector<A: Float + Send + Sync> {
336 plateau_threshold: A,
337 min_plateau_duration: usize,
338 current_plateau_length: usize,
339 plateau_confidence: A,
340}
341
342#[derive(Debug, Clone)]
344pub struct OverfittingDetector<A: Float + Send + Sync> {
345 train_loss_history: VecDeque<A>,
346 val_loss_history: VecDeque<A>,
347 overfitting_threshold: A,
348 early_stopping_patience: usize,
349}
350
351#[derive(Debug, Clone)]
353pub struct LearningEfficiencyTracker<A: Float + Send + Sync> {
354 loss_reduction_per_step: VecDeque<A>,
355 parameter_change_magnitude: VecDeque<A>,
356 efficiency_score: A,
357 efficiency_trend: TrendType,
358}
359
360#[derive(Debug, Clone)]
362pub struct ConceptDriftDetector<A: Float + Send + Sync> {
363 detection_method: DriftDetectionMethod,
364 drift_threshold: A,
365 window_size: usize,
366 drift_confidence: A,
367 last_drift_time: Option<Instant>,
368}
369
/// Algorithm selector for a [`ConceptDriftDetector`]; the variants are names
/// only, no detection logic is implemented in this file.
#[derive(Debug, Clone, Copy)]
pub enum DriftDetectionMethod {
    ADWIN,
    DDM,
    EDDM,
    PageHinkley,
    KSWIN,
    Statistical,
}
379
380#[derive(Debug, Clone)]
382pub struct DistributionTracker<A: Float + Send + Sync> {
383 feature_distributions: HashMap<usize, FeatureDistribution<A>>,
384 kl_divergence_threshold: A,
385 wasserstein_distance_threshold: A,
386 distribution_drift_score: A,
387}
388
389#[derive(Debug, Clone)]
390pub struct FeatureDistribution<A: Float + Send + Sync> {
391 mean: A,
392 variance: A,
393 histogram: Vec<A>,
394 last_update: Instant,
395}
396
397#[derive(Debug, Clone)]
399pub struct AdaptationSpeedController<A: Float + Send + Sync> {
400 base_adaptation_rate: A,
401 current_adaptation_rate: A,
402 acceleration_factor: A,
403 deceleration_factor: A,
404 momentum: A,
405}
406
407#[derive(Debug, Clone)]
409pub struct DriftSeverityAssessor<A: Float + Send + Sync> {
410 severity_levels: Vec<DriftSeverityLevel<A>>,
411 current_severity: DriftSeverityLevel<A>,
412 severity_history: VecDeque<DriftSeverityLevel<A>>,
413}
414
415#[derive(Debug, Clone)]
416pub struct DriftSeverityLevel<A: Float + Send + Sync> {
417 level: DriftSeverity,
418 magnitude: A,
419 recommended_lr_adjustment: A,
420 adaptation_urgency: A,
421}
422
/// Qualitative drift severity label.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DriftSeverity {
    None,
    Mild,
    Moderate,
    Severe,
    Critical,
}
431
/// Memory-usage readings; `memory_pressure` drives the resource adapter's vote.
///
/// NOTE(review): no code in this file writes these fields, so they keep
/// their `Default` (zero / empty) values.
#[derive(Debug, Clone, Default)]
pub struct MemoryUsageTracker {
    current_usage_mb: f64,
    peak_usage_mb: f64,
    usage_history: VecDeque<f64>,
    memory_pressure: f64,
}
440
/// Per-step compute-time readings (zero / empty by default; not yet written
/// by code in this file).
#[derive(Debug, Clone, Default)]
pub struct ComputationTimeTracker {
    step_times: VecDeque<Duration>,
    average_step_time: Duration,
    time_budget: Duration,
    time_pressure: f64,
}
448
/// Energy-consumption readings (zero / empty by default; not yet written by
/// code in this file).
#[derive(Debug, Clone, Default)]
pub struct EnergyConsumptionTracker {
    energy_per_step: VecDeque<f64>,
    cumulative_energy: f64,
    energy_budget: f64,
    energy_efficiency: f64,
}
456
457#[derive(Debug, Clone)]
458pub struct ThroughputRequirements<A: Float + Send + Sync> {
459 min_samples_per_second: A,
460 target_samples_per_second: A,
461 current_throughput: A,
462 throughput_deficit: A,
463}
464
465#[derive(Debug, Clone)]
466pub struct ResourceBudgetManager<A: Float + Send + Sync> {
467 memory_budget_mb: f64,
468 compute_budget_seconds: f64,
469 energy_budget_joules: f64,
470 budget_utilization: A,
471 budget_violations: usize,
472}
473
474#[derive(Debug, Clone)]
476pub struct LearningRatePredictorNetwork<A: Float + Send + Sync> {
477 input_features: Vec<FeatureType>,
478 hidden_layers: Vec<usize>,
479 weights: Vec<Array2<A>>,
480 biases: Vec<Array1<A>>,
481 prediction_confidence: A,
482}
483
/// Input-feature identifiers for [`LearningRatePredictorNetwork`].
#[derive(Debug, Clone, Copy)]
pub enum FeatureType {
    GradientNorm,
    LossValue,
    LossGradient,
    ParameterNorm,
    UpdateMagnitude,
    LearningProgress,
    ResourceUtilization,
    DataCharacteristics,
}
495
496#[derive(Debug, Clone)]
498pub struct HyperparameterUpdate<A: Float + Send + Sync> {
499 timestamp: Instant,
500 old_lr: A,
501 new_lr: A,
502 features: Array1<A>,
503 reward: A, exploration_bonus: A,
505}
506
507#[derive(Debug, Clone)]
509pub struct ExplorationStrategy<A: Float + Send + Sync> {
510 strategy_type: ExplorationStrategyType,
511 exploration_rate: A,
512 exploitation_rate: A,
513 arm_rewards: HashMap<usize, A>,
514 arm_counts: HashMap<usize, usize>,
515}
516
/// Exploration algorithm selector for [`ExplorationStrategy`].
#[derive(Debug, Clone, Copy)]
pub enum ExplorationStrategyType {
    EpsilonGreedy,
    UCB1,
    ThompsonSampling,
    LinUCB,
    ContextualBandit,
}
525
526#[derive(Debug, Clone)]
528pub struct TransferLearner<A: Float + Send + Sync> {
529 source_task_data: Vec<TaskData<A>>,
530 similarity_metrics: Vec<TaskSimilarityMetric<A>>,
531 transfer_weights: Array1<A>,
532 transfer_confidence: A,
533}
534
535#[derive(Debug, Clone)]
536pub struct TaskData<A: Float + Send + Sync> {
537 task_id: String,
538 optimal_lr_sequence: Vec<A>,
539 task_features: Array1<A>,
540 performance_curve: Vec<A>,
541}
542
543#[derive(Debug, Clone)]
544pub struct TaskSimilarityMetric<A: Float + Send + Sync> {
545 metric_type: SimilarityMetricType,
546 similarity_score: A,
547 weight: A,
548}
549
/// Aspect along which two tasks are compared.
#[derive(Debug, Clone, Copy)]
pub enum SimilarityMetricType {
    DatasetSize,
    ModelArchitecture,
    LossFunction,
    DataDistribution,
    OptimizationLandscape,
}
558
559#[derive(Debug, Clone, Default)]
561pub struct AdaptationStatistics<A: Float + Send + Sync> {
562 pub total_adaptations: usize,
564
565 pub successful_adaptations: usize,
567
568 pub avg_adaptation_frequency: A,
570
571 pub lr_volatility: A,
573
574 pub signal_reliability_scores: HashMap<AdaptationSignalType, A>,
576
577 pub signal_effectiveness: HashMap<AdaptationSignalType, A>,
579
580 pub resource_efficiency_gains: A,
582
583 pub convergence_speed_improvement: A,
585}
586
587impl<A: Float + Default + Clone + Send + Sync + Send + Sync> EnhancedAdaptiveLRController<A> {
588 pub fn new(config: AdaptiveLRConfig<A>) -> Result<Self> {
590 let adaptation_strategy = MultiSignalAdaptationStrategy::new(&config)?;
591 let gradient_adapter = GradientBasedAdapter::new(&config)?;
592 let performance_adapter = PerformanceBasedAdapter::new(&config)?;
593 let drift_adapter = DriftAwareAdapter::new(&config)?;
594 let resource_adapter = ResourceAwareAdapter::new(&config)?;
595 let meta_optimizer = MetaOptimizer::new(&config)?;
596
597 Ok(Self {
598 current_lr: config.base_lr,
599 base_lr: config.base_lr,
600 min_lr: config.min_lr,
601 max_lr: config.max_lr,
602 adaptation_strategy,
603 gradient_adapter,
604 performance_adapter,
605 drift_adapter,
606 resource_adapter,
607 meta_optimizer,
608 adaptation_history: VecDeque::with_capacity(config.history_window_size),
609 config,
610 })
611 }
612
613 pub fn update_learning_rate(
615 &mut self,
616 gradients: &Array1<A>,
617 loss: A,
618 metrics: &HashMap<String, A>,
619 step: usize,
620 ) -> Result<A> {
621 let mut signals = Vec::new();
623
624 if self.config.enable_gradient_adaptation {
625 if let Ok(signal) = self.gradient_adapter.generate_signal(gradients, step) {
626 signals.push(signal);
627 }
628 }
629
630 if self.config.enable_performance_adaptation {
631 if let Ok(signal) = self
632 .performance_adapter
633 .generate_signal(loss, metrics, step)
634 {
635 signals.push(signal);
636 }
637 }
638
639 if self.config.enable_drift_adaptation {
640 if let Ok(signal) = self.drift_adapter.generate_signal(gradients, step) {
641 signals.push(signal);
642 }
643 }
644
645 if self.config.enable_resource_adaptation {
646 if let Ok(signal) = self.resource_adapter.generate_signal(step) {
647 signals.push(signal);
648 }
649 }
650
651 let decision = self.adaptation_strategy.resolve_signals(signals, step)?;
653
654 if self.config.enable_meta_learning {
656 let meta_adjustment = self.meta_optimizer.meta_optimize(&decision, step)?;
657 self.current_lr = self.apply_meta_adjustment(decision.new_lr, meta_adjustment);
658 } else {
659 self.current_lr = decision.new_lr;
660 }
661
662 self.current_lr = self
664 .current_lr
665 .clamp(self.config.min_lr, self.config.max_lr);
666
667 let event = AdaptationEvent {
669 timestamp: Instant::now(),
670 old_lr: decision.new_lr, new_lr: self.current_lr,
672 trigger_signals: decision.contributing_signals,
673 adaptation_reason: decision.rationale,
674 confidence: decision.confidence,
675 effectiveness_score: None, };
677
678 self.adaptation_history.push_back(event);
679 if self.adaptation_history.len() > self.config.history_window_size {
680 self.adaptation_history.pop_front();
681 }
682
683 Ok(self.current_lr)
684 }
685
686 pub fn get_current_lr(&self) -> A {
688 self.current_lr
689 }
690
691 pub fn get_adaptation_statistics(&self) -> AdaptationStatistics<A> {
693 let total_adaptations = self.adaptation_history.len();
694 let successful_adaptations = self
695 .adaptation_history
696 .iter()
697 .filter(|event| {
698 event
699 .effectiveness_score
700 .is_some_and(|score| score > A::zero())
701 })
702 .count();
703
704 let lr_volatility = if !self.adaptation_history.is_empty() {
705 let lr_values: Vec<A> = self
706 .adaptation_history
707 .iter()
708 .map(|event| event.new_lr)
709 .collect();
710
711 let mean_lr = lr_values.iter().fold(A::zero(), |acc, &lr| acc + lr)
712 / A::from(lr_values.len()).unwrap();
713
714 let variance = lr_values
715 .iter()
716 .map(|&lr| {
717 let diff = lr - mean_lr;
718 diff * diff
719 })
720 .fold(A::zero(), |acc, var| acc + var)
721 / A::from(lr_values.len()).unwrap();
722
723 variance.sqrt()
724 } else {
725 A::zero()
726 };
727
728 AdaptationStatistics {
729 total_adaptations,
730 successful_adaptations,
731 lr_volatility,
732 ..Default::default()
733 }
734 }
735
736 fn apply_meta_adjustment(&self, base_lr: A, meta_adjustment: A) -> A {
738 let alpha = A::from(0.7).unwrap(); let beta = A::from(0.3).unwrap(); alpha * base_lr + beta * meta_adjustment
743 }
744
745 pub fn evaluate_adaptation_effectiveness(&mut self, performance_improvement: A) {
747 if let Some(last_event) = self.adaptation_history.back_mut() {
748 last_event.effectiveness_score = Some(performance_improvement);
749
750 for signal_type in &last_event.trigger_signals {
752 self.adaptation_strategy
753 .update_signal_reliability(*signal_type, performance_improvement);
754 }
755 }
756 }
757
758 pub fn reset(&mut self) {
760 self.current_lr = self.base_lr;
761 self.adaptation_history.clear();
762 self.gradient_adapter.reset();
763 self.performance_adapter.reset();
764 self.drift_adapter.reset();
765 self.resource_adapter.reset();
766 self.meta_optimizer.reset();
767 }
768}
769
770impl<A: Float + Default + Clone + Send + Sync + Send + Sync> MultiSignalAdaptationStrategy<A> {
774 fn new(config: &AdaptiveLRConfig<A>) -> Result<Self> {
775 Ok(Self {
776 signal_weights: HashMap::new(),
777 voting_history: VecDeque::new(),
778 conflict_resolution: ConflictResolution::WeightedAverage,
779 signal_reliability: HashMap::new(),
780 last_decision: None,
781 })
782 }
783
784 fn resolve_signals(
785 &mut self,
786 signals: Vec<SignalVote<A>>,
787 _step: usize,
788 ) -> Result<AdaptationDecision<A>> {
789 if signals.is_empty() {
790 return Ok(AdaptationDecision {
791 new_lr: A::from(0.001).unwrap(),
792 lr_multiplier: A::one(),
793 contributing_signals: vec![],
794 confidence: A::zero(),
795 rationale: "No signals available".to_string(),
796 timestamp: Instant::now(),
797 });
798 }
799
800 let total_weight = signals
802 .iter()
803 .map(|s| s.confidence)
804 .fold(A::zero(), |acc, c| acc + c);
805
806 let weighted_change = signals
807 .iter()
808 .map(|s| s.recommended_lr_change * s.confidence)
809 .fold(A::zero(), |acc, change| acc + change)
810 / total_weight;
811
812 let contributing_signals = signals.iter().map(|s| s.signal_type).collect();
813
814 Ok(AdaptationDecision {
815 new_lr: A::from(0.001).unwrap() * weighted_change,
816 lr_multiplier: weighted_change,
817 contributing_signals,
818 confidence: total_weight / A::from(signals.len()).unwrap(),
819 rationale: "Weighted average of adaptation signals".to_string(),
820 timestamp: Instant::now(),
821 })
822 }
823
824 fn update_signal_reliability(&mut self, signal_type: AdaptationSignalType, effectiveness: A) {
825 let reliability = self
826 .signal_reliability
827 .entry(signal_type)
828 .or_insert(A::from(0.5).unwrap());
829
830 let alpha = A::from(0.1).unwrap();
832 *reliability = (*reliability) * (A::one() - alpha) + effectiveness * alpha;
833 }
834}
835
836impl<A: Float + Default + Clone + Send + Sync + Send + Sync> GradientBasedAdapter<A> {
837 fn new(config: &AdaptiveLRConfig<A>) -> Result<Self> {
838 Ok(Self {
839 magnitude_history: VecDeque::new(),
840 direction_variance_history: VecDeque::new(),
841 norm_statistics: GradientNormStatistics::default(),
842 snr_estimator: SignalToNoiseEstimator::default(),
843 staleness_detector: GradientStalenessDetector::default(),
844 })
845 }
846
847 fn generate_signal(&mut self, gradients: &Array1<A>, step: usize) -> Result<SignalVote<A>> {
848 let magnitude = gradients
849 .iter()
850 .map(|&g| g * g)
851 .fold(A::zero(), |acc, x| acc + x)
852 .sqrt();
853 self.magnitude_history.push_back(magnitude);
854
855 if self.magnitude_history.len() > 100 {
856 self.magnitude_history.pop_front();
857 }
858
859 let recommended_change = if magnitude > A::from(1.0).unwrap() {
861 A::from(0.9).unwrap() } else if magnitude < A::from(0.01).unwrap() {
863 A::from(1.1).unwrap() } else {
865 A::one() };
867
868 Ok(SignalVote {
869 signal_type: AdaptationSignalType::GradientMagnitude,
870 recommended_lr_change: recommended_change,
871 confidence: A::from(0.7).unwrap(),
872 reasoning: "Gradient magnitude-based adaptation".to_string(),
873 timestamp: Instant::now(),
874 })
875 }
876
877 fn reset(&mut self) {
878 self.magnitude_history.clear();
879 self.direction_variance_history.clear();
880 }
881}
882
883impl<A: Float + Default + Clone + Send + Sync + Send + Sync> PerformanceBasedAdapter<A> {
884 fn new(config: &AdaptiveLRConfig<A>) -> Result<Self> {
885 Ok(Self {
886 metric_history: HashMap::new(),
887 trend_analyzer: PerformanceTrendAnalyzer::default(),
888 plateau_detector: PlateauDetector::default(),
889 overfitting_detector: OverfittingDetector::default(),
890 efficiency_tracker: LearningEfficiencyTracker::default(),
891 })
892 }
893
894 fn generate_signal(
895 &mut self,
896 loss: A,
897 metrics: &HashMap<String, A>,
898 _step: usize,
899 ) -> Result<SignalVote<A>> {
900 let loss_history = self.metric_history.entry("loss".to_string()).or_default();
901
902 loss_history.push_back(loss);
903 if loss_history.len() > 50 {
904 loss_history.pop_front();
905 }
906
907 let recommended_change = if loss_history.len() >= 2 {
909 let recent_loss = loss_history.back().unwrap();
910 let prev_loss = loss_history.get(loss_history.len() - 2).unwrap();
911
912 if *recent_loss > *prev_loss {
913 A::from(0.95).unwrap() } else {
915 A::from(1.02).unwrap() }
917 } else {
918 A::one()
919 };
920
921 Ok(SignalVote {
922 signal_type: AdaptationSignalType::LossProgression,
923 recommended_lr_change: recommended_change,
924 confidence: A::from(0.8).unwrap(),
925 reasoning: "Loss progression analysis".to_string(),
926 timestamp: Instant::now(),
927 })
928 }
929
930 fn reset(&mut self) {
931 self.metric_history.clear();
932 }
933}
934
935impl<A: Float + Default + Clone + Send + Sync + Send + Sync> DriftAwareAdapter<A> {
936 fn new(config: &AdaptiveLRConfig<A>) -> Result<Self> {
937 Ok(Self {
938 drift_detectors: vec![],
939 distribution_tracker: DistributionTracker::default(),
940 adaptation_speed: AdaptationSpeedController::default(),
941 drift_severity: DriftSeverityAssessor::default(),
942 })
943 }
944
945 fn generate_signal(&mut self, gradients: &Array1<A>, step: usize) -> Result<SignalVote<A>> {
946 Ok(SignalVote {
948 signal_type: AdaptationSignalType::ConceptDrift,
949 recommended_lr_change: A::one(),
950 confidence: A::from(0.5).unwrap(),
951 reasoning: "No drift detected".to_string(),
952 timestamp: Instant::now(),
953 })
954 }
955
956 fn reset(&mut self) {
957 }
959}
960
961impl<A: Float + Default + Clone + Send + Sync + Send + Sync> ResourceAwareAdapter<A> {
962 fn new(config: &AdaptiveLRConfig<A>) -> Result<Self> {
963 Ok(Self {
964 memory_tracker: MemoryUsageTracker::default(),
965 compute_tracker: ComputationTimeTracker::default(),
966 energy_tracker: EnergyConsumptionTracker::default(),
967 throughput_requirements: ThroughputRequirements {
968 min_samples_per_second: A::from(100.0).unwrap(),
969 target_samples_per_second: A::from(1000.0).unwrap(),
970 current_throughput: A::from(500.0).unwrap(),
971 throughput_deficit: A::zero(),
972 },
973 budget_manager: ResourceBudgetManager {
974 memory_budget_mb: 1000.0,
975 compute_budget_seconds: 3600.0,
976 energy_budget_joules: 1000.0,
977 budget_utilization: A::from(0.5).unwrap(),
978 budget_violations: 0,
979 },
980 })
981 }
982
983 fn generate_signal(&mut self, step: usize) -> Result<SignalVote<A>> {
984 let memory_pressure = self.memory_tracker.memory_pressure;
986
987 let recommended_change = if memory_pressure > 0.8 {
988 A::from(0.9).unwrap() } else if memory_pressure < 0.3 {
990 A::from(1.05).unwrap() } else {
992 A::one()
993 };
994
995 Ok(SignalVote {
996 signal_type: AdaptationSignalType::ResourceUtilization,
997 recommended_lr_change: recommended_change,
998 confidence: A::from(0.6).unwrap(),
999 reasoning: format!("Memory pressure: {:.2}", memory_pressure),
1000 timestamp: Instant::now(),
1001 })
1002 }
1003
1004 fn reset(&mut self) {
1005 self.memory_tracker = MemoryUsageTracker::default();
1006 self.compute_tracker = ComputationTimeTracker::default();
1007 self.energy_tracker = EnergyConsumptionTracker::default();
1008 }
1009}
1010
1011impl<A: Float + Default + Clone + Send + Sync + Send + Sync> MetaOptimizer<A> {
1012 fn new(config: &AdaptiveLRConfig<A>) -> Result<Self> {
1013 Ok(Self {
1014 lr_predictor: LearningRatePredictorNetwork::default(),
1015 optimization_history: VecDeque::new(),
1016 exploration_strategy: ExplorationStrategy::default(),
1017 transfer_learner: TransferLearner::default(),
1018 })
1019 }
1020
1021 fn meta_optimize(&mut self, decision: &AdaptationDecision<A>, step: usize) -> Result<A> {
1022 Ok(A::from(0.001).unwrap())
1024 }
1025
1026 fn reset(&mut self) {
1027 self.optimization_history.clear();
1028 }
1029}
1030
1031impl<A: Float + Default + Send + Sync + Send + Sync> Default for GradientNormStatistics<A> {
1033 fn default() -> Self {
1034 Self {
1035 mean: A::default(),
1036 variance: A::default(),
1037 skewness: A::default(),
1038 kurtosis: A::default(),
1039 percentiles: vec![A::default(); 5],
1040 autocorrelation: A::default(),
1041 }
1042 }
1043}
1044
1045impl<A: Float + Default + Send + Sync + Send + Sync> Default for SignalToNoiseEstimator<A> {
1046 fn default() -> Self {
1047 Self {
1048 signal_estimate: A::default(),
1049 noise_estimate: A::default(),
1050 snr_history: VecDeque::new(),
1051 estimation_method: SNREstimationMethod::MovingAverage,
1052 }
1053 }
1054}
1055
1056impl<A: Float + Default + Send + Sync + Send + Sync> Default for GradientStalenessDetector<A> {
1057 fn default() -> Self {
1058 Self {
1059 staleness_threshold: Duration::from_secs(1),
1060 gradient_timestamps: VecDeque::new(),
1061 staleness_impact_model: StalenessImpactModel::default(),
1062 }
1063 }
1064}
1065
1066impl<A: Float + Default + Send + Sync + Send + Sync> Default for StalenessImpactModel<A> {
1067 fn default() -> Self {
1068 Self {
1069 staleness_penalty: A::default(),
1070 compensation_factor: A::default(),
1071 impact_history: VecDeque::new(),
1072 }
1073 }
1074}
1075
1076impl<A: Float + Default + Send + Sync + Send + Sync> Default for PerformanceTrendAnalyzer<A> {
1077 fn default() -> Self {
1078 Self {
1079 trend_detection_window: 10,
1080 trend_types: vec![],
1081 trend_strength: A::default(),
1082 trend_duration: Duration::from_secs(0),
1083 }
1084 }
1085}
1086
1087impl<A: Float + Default + Send + Sync + Send + Sync> Default for PlateauDetector<A> {
1088 fn default() -> Self {
1089 Self {
1090 plateau_threshold: A::default(),
1091 min_plateau_duration: 5,
1092 current_plateau_length: 0,
1093 plateau_confidence: A::default(),
1094 }
1095 }
1096}
1097
1098impl<A: Float + Default + Send + Sync + Send + Sync> Default for OverfittingDetector<A> {
1099 fn default() -> Self {
1100 Self {
1101 train_loss_history: VecDeque::new(),
1102 val_loss_history: VecDeque::new(),
1103 overfitting_threshold: A::default(),
1104 early_stopping_patience: 10,
1105 }
1106 }
1107}
1108
1109impl<A: Float + Default + Send + Sync + Send + Sync> Default for LearningEfficiencyTracker<A> {
1110 fn default() -> Self {
1111 Self {
1112 loss_reduction_per_step: VecDeque::new(),
1113 parameter_change_magnitude: VecDeque::new(),
1114 efficiency_score: A::default(),
1115 efficiency_trend: TrendType::Improving,
1116 }
1117 }
1118}
1119
1120impl<A: Float + Default + Send + Sync + Send + Sync> Default for DistributionTracker<A> {
1121 fn default() -> Self {
1122 Self {
1123 feature_distributions: HashMap::new(),
1124 kl_divergence_threshold: A::default(),
1125 wasserstein_distance_threshold: A::default(),
1126 distribution_drift_score: A::default(),
1127 }
1128 }
1129}
1130
1131impl<A: Float + Default + Send + Sync + Send + Sync> Default for AdaptationSpeedController<A> {
1132 fn default() -> Self {
1133 Self {
1134 base_adaptation_rate: A::from(0.1).unwrap_or_default(),
1135 current_adaptation_rate: A::from(0.1).unwrap_or_default(),
1136 acceleration_factor: A::from(1.1).unwrap_or_default(),
1137 deceleration_factor: A::from(0.9).unwrap_or_default(),
1138 momentum: A::default(),
1139 }
1140 }
1141}
1142
1143impl<A: Float + Default + Send + Sync + Send + Sync> Default for DriftSeverityAssessor<A> {
1144 fn default() -> Self {
1145 Self {
1146 severity_levels: vec![],
1147 current_severity: DriftSeverityLevel::default(),
1148 severity_history: VecDeque::new(),
1149 }
1150 }
1151}
1152
1153impl<A: Float + Default + Send + Sync + Send + Sync> Default for DriftSeverityLevel<A> {
1154 fn default() -> Self {
1155 Self {
1156 level: DriftSeverity::None,
1157 magnitude: A::default(),
1158 recommended_lr_adjustment: A::one(),
1159 adaptation_urgency: A::default(),
1160 }
1161 }
1162}
1163
1164impl<A: Float + Default + Send + Sync + Send + Sync> Default for LearningRatePredictorNetwork<A> {
1165 fn default() -> Self {
1166 Self {
1167 input_features: vec![],
1168 hidden_layers: vec![],
1169 weights: vec![],
1170 biases: vec![],
1171 prediction_confidence: A::default(),
1172 }
1173 }
1174}
1175
1176impl<A: Float + Default + Send + Sync + Send + Sync> Default for ExplorationStrategy<A> {
1177 fn default() -> Self {
1178 Self {
1179 strategy_type: ExplorationStrategyType::EpsilonGreedy,
1180 exploration_rate: A::from(0.1).unwrap_or_default(),
1181 exploitation_rate: A::from(0.9).unwrap_or_default(),
1182 arm_rewards: HashMap::new(),
1183 arm_counts: HashMap::new(),
1184 }
1185 }
1186}
1187
1188impl<A: Float + Default + Send + Sync + Send + Sync> Default for TransferLearner<A> {
1189 fn default() -> Self {
1190 Self {
1191 source_task_data: vec![],
1192 similarity_metrics: vec![],
1193 transfer_weights: Array1::from_vec(vec![]),
1194 transfer_confidence: A::default(),
1195 }
1196 }
1197}
1198
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Shared configuration for every test.
    ///
    /// Gradient and performance adapters are enabled; everything else is off.
    /// The history window holds 100 events.
    fn test_config() -> AdaptiveLRConfig<f32> {
        AdaptiveLRConfig {
            base_lr: 0.01,
            min_lr: 1e-6,
            max_lr: 1.0,
            enable_gradient_adaptation: true,
            enable_performance_adaptation: true,
            enable_drift_adaptation: false,
            enable_resource_adaptation: false,
            enable_meta_learning: false,
            history_window_size: 100,
            adaptation_frequency: 10,
            adaptation_sensitivity: 0.1,
            use_ensemble_voting: true,
        }
    }

    #[test]
    fn test_enhanced_adaptive_lr_controller_creation() {
        let built = EnhancedAdaptiveLRController::<f32>::new(test_config());
        assert!(built.is_ok());
    }

    #[test]
    fn test_learning_rate_update() {
        let mut controller = EnhancedAdaptiveLRController::<f32>::new(test_config()).unwrap();
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.05]);
        let metrics = HashMap::new();

        let outcome = controller.update_learning_rate(&gradients, 0.5, &metrics, 1);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap() > 0.0);
    }

    #[test]
    fn test_adaptation_statistics() {
        let controller = EnhancedAdaptiveLRController::<f32>::new(test_config()).unwrap();
        let stats = controller.get_adaptation_statistics();

        assert_eq!(stats.total_adaptations, 0);
        assert_eq!(stats.successful_adaptations, 0);
    }
}