optirs_learned/transformer/training/
evaluation.rs

1use std::fmt::Debug;
2// Evaluation metrics and methods for transformer optimization
3//
4// This module implements comprehensive evaluation strategies for assessing
5// the performance of transformer-based learned optimizers across various metrics.
6
7#[allow(dead_code)]
8use scirs2_core::ndarray::{Array1, Array2, Array3};
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11
12use crate::error::{OptimError, Result};
13
/// Evaluation strategies for transformer optimizers
///
/// Selects the evaluation regime a `TransformerEvaluator` is configured for.
/// `Copy` so it can be passed and stored by value.
#[derive(Debug, Clone, Copy)]
pub enum EvaluationStrategy {
    /// Single-task evaluation
    SingleTask,
    /// Multi-task evaluation
    MultiTask,
    /// Cross-domain evaluation
    CrossDomain,
    /// Few-shot evaluation
    FewShot,
    /// Continual learning evaluation
    ContinualLearning,
    /// Robustness evaluation
    Robustness,
    /// Efficiency evaluation
    Efficiency,
    /// Comprehensive evaluation (all of the above)
    Comprehensive,
}
34
/// Performance evaluator for transformer optimizers
///
/// Central entry point for scoring a learned-optimizer run: it derives
/// convergence and efficiency metrics from raw trajectories and keeps a
/// bounded rolling history of results for summary statistics.
#[derive(Debug, Clone)]
pub struct TransformerEvaluator<T: Float + Debug + Send + Sync + 'static> {
    /// Evaluation strategy this evaluator was constructed with
    strategy: EvaluationStrategy,

    /// Tunable evaluation parameters (tolerances, episode counts, ...)
    eval_params: EvaluationParams<T>,

    /// Per-metric calculators, keyed by metric name
    metric_calculators: HashMap<String, MetricCalculator<T>>,

    /// Rolling history of past evaluation results (bounded; oldest dropped)
    performance_history: VecDeque<EvaluationResult<T>>,

    /// Baseline comparison data, keyed by baseline name
    baseline_comparisons: HashMap<String, BaselineComparison<T>>,

    /// Statistical analyzers applied to performance data
    statistical_analyzers: Vec<StatisticalAnalyzer<T>>,
}
56
/// Evaluation parameters
///
/// Construct via `Default` for sensible values (e.g. 1e-6 convergence
/// tolerance, 0.95 confidence level).
#[derive(Debug, Clone)]
pub struct EvaluationParams<T: Float + Debug + Send + Sync + 'static> {
    /// Number of evaluation episodes to run
    num_episodes: usize,

    /// How often (in steps) evaluation is performed
    eval_frequency: usize,

    /// Relative-change threshold below which a run counts as converged
    convergence_tolerance: T,

    /// Maximum number of evaluation steps per run
    max_eval_steps: usize,

    /// Confidence level for statistical tests (a probability, e.g. 0.95)
    confidence_level: T,

    /// Number of bootstrap samples for resampling-based statistics
    bootstrap_samples: usize,

    /// Number of cross-validation folds
    cv_folds: usize,

    /// Severity level applied in robustness tests
    robustness_severity: T,
}
84
/// Evaluation result
///
/// One completed evaluation of an optimizer run, as produced by
/// `TransformerEvaluator::evaluate`.
#[derive(Debug, Clone)]
pub struct EvaluationResult<T: Float + Debug + Send + Sync + 'static> {
    /// Unique evaluation identifier ("eval_<task>_<index>")
    eval_id: String,

    /// Identifier of the task that was evaluated
    task_id: String,

    /// Scalar performance metrics keyed by name
    /// (e.g. "final_loss", "convergence_rate", "sample_efficiency")
    metrics: HashMap<String, T>,

    /// Convergence information derived from the loss trajectory
    convergence_info: ConvergenceInfo<T>,

    /// Resource/efficiency metrics for the run
    efficiency_metrics: EfficiencyMetrics<T>,

    /// Statistical significance analysis (currently placeholder values)
    statistical_significance: StatisticalSignificance<T>,

    /// Evaluation timestamp — the history index at creation time,
    /// not wall-clock time
    timestamp: usize,
}
109
/// Convergence information
///
/// Derived from a loss trajectory by `compute_convergence_info`.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo<T: Float + Debug + Send + Sync + 'static> {
    /// Whether convergence was achieved
    converged: bool,

    /// Step index at which convergence was detected (`None` if never)
    steps_to_convergence: Option<usize>,

    /// Final loss value (last element of the trajectory)
    final_loss: T,

    /// Average relative improvement per step
    convergence_rate: T,

    /// Full loss trajectory, stored verbatim
    loss_trajectory: Vec<T>,

    /// Per-step gradient norms, stored verbatim
    gradient_norms: Vec<T>,
}
131
/// Efficiency metrics
///
/// Resource-usage measures for one run. FLOPs and energy are rough
/// estimates computed by `compute_efficiency_metrics`, not measurements.
#[derive(Debug, Clone)]
pub struct EfficiencyMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Wall-clock time (units chosen by the caller — TODO confirm convention)
    wall_time: T,

    /// Estimated computational FLOPs (fixed cost per step)
    flops: u64,

    /// Peak memory usage (presumably bytes; verify against callers)
    peak_memory: u64,

    /// Parameter efficiency: 1 / (memory + 1)
    parameter_efficiency: T,

    /// Sample efficiency: steps / (wall_time + 1)
    sample_efficiency: T,

    /// Crude energy-consumption proxy (time x memory, scaled)
    energy_consumption: T,
}
153
/// Statistical significance analysis
///
/// NOTE(review): currently populated with placeholder constants by
/// `compute_statistical_significance`; no real tests are run yet.
#[derive(Debug, Clone)]
pub struct StatisticalSignificance<T: Float + Debug + Send + Sync + 'static> {
    /// P-value for performance comparison
    p_value: T,

    /// Effect size (intended as Cohen's d)
    effect_size: T,

    /// Confidence interval as (lower, upper) bounds
    confidence_interval: (T, T),

    /// Statistical power of the test
    statistical_power: T,

    /// Test statistic value
    test_statistic: T,
}
172
/// Metric calculator for a single named metric
///
/// Accumulates a bounded history of observed values and aggregates it
/// according to `aggregation_method`.
#[derive(Debug, Clone)]
pub struct MetricCalculator<T: Float + Debug + Send + Sync + 'static> {
    /// Metric name (also used as the key in the evaluator's calculator map)
    metric_name: String,

    /// Named parameters for the calculation (currently unused placeholder)
    calculation_params: HashMap<String, T>,

    /// Historical values for trend analysis (bounded; oldest dropped)
    historical_values: VecDeque<T>,

    /// How the history is reduced to a single value
    aggregation_method: AggregationMethod,
}
188
/// Baseline comparison data
///
/// Registered via `TransformerEvaluator::add_baseline`; the derived fields
/// (`improvement`, `relative_performance`, `win_rate`) start empty/zero.
#[derive(Debug, Clone)]
pub struct BaselineComparison<T: Float + Debug + Send + Sync + 'static> {
    /// Baseline name
    baseline_name: String,

    /// Baseline performance, keyed by metric name
    baseline_performance: HashMap<String, T>,

    /// Improvement over baseline per metric
    improvement: HashMap<String, T>,

    /// Relative performance per metric
    relative_performance: HashMap<String, T>,

    /// Win rate against this baseline
    win_rate: T,
}
207
/// Statistical analyzer for performance data
///
/// NOTE(review): stored on the evaluator but not yet invoked by any visible
/// code path; retained for future analyses.
#[derive(Debug, Clone)]
pub struct StatisticalAnalyzer<T: Float + Debug + Send + Sync + 'static> {
    /// Analyzer name
    analyzer_name: String,

    /// Analysis parameters keyed by name
    analysis_params: HashMap<String, T>,

    /// Cache of previously computed results keyed by name
    results_cache: HashMap<String, T>,
}
220
/// Robustness test suite
///
/// Bundle of pre-computed robustness test outcomes consumed by
/// `TransformerEvaluator::evaluate_robustness`.
#[derive(Debug, Clone)]
pub struct RobustnessTestSuite<T: Float + Debug + Send + Sync + 'static> {
    /// Noise injection tests
    noise_tests: Vec<NoiseTest<T>>,

    /// Adversarial perturbation tests
    adversarial_tests: Vec<AdversarialTest<T>>,

    /// Hyperparameter sensitivity tests
    sensitivity_tests: Vec<SensitivityTest<T>>,

    /// Distribution shift tests
    distribution_tests: Vec<DistributionTest<T>>,
}
236
/// Outcome of a single noise-injection robustness test
#[derive(Debug, Clone)]
pub struct NoiseTest<T: Float + Debug + Send + Sync + 'static> {
    /// Kind of noise injected
    noise_type: NoiseType,
    /// Magnitude of the injected noise
    noise_level: T,
    /// Observed performance degradation (0 = unaffected; scored as 1 - this)
    performance_degradation: T,
}
244
/// Outcome of a single adversarial-perturbation robustness test
#[derive(Debug, Clone)]
pub struct AdversarialTest<T: Float + Debug + Send + Sync + 'static> {
    /// Attack algorithm used
    attack_type: AttackType,
    /// Strength of the attack perturbation
    attack_strength: T,
    /// Robustness score under attack (higher is better)
    robustness_score: T,
}
251
/// Outcome of a single hyperparameter-sensitivity test
#[derive(Debug, Clone)]
pub struct SensitivityTest<T: Float + Debug + Send + Sync + 'static> {
    /// Name of the hyperparameter that was varied
    parameter_name: String,
    /// (min, max) range over which the parameter was swept
    parameter_range: (T, T),
    /// Sensitivity score (higher = more sensitive; scored as 1/(1 + this))
    sensitivity_score: T,
}
258
/// Outcome of a single distribution-shift robustness test
#[derive(Debug, Clone)]
pub struct DistributionTest<T: Float + Debug + Send + Sync + 'static> {
    /// Kind of distribution shift applied
    shift_type: DistributionShiftType,
    /// Magnitude of the shift
    shift_magnitude: T,
    /// How well the optimizer adapted to the shift (higher is better)
    adaptation_score: T,
}
265
/// Aggregation methods for reducing a metric history to a single value
#[derive(Debug, Clone, Copy)]
pub enum AggregationMethod {
    /// Arithmetic mean
    Mean,
    /// Median value
    Median,
    /// Maximum value
    Max,
    /// Minimum value
    Min,
    /// Weighted average
    WeightedAverage,
    /// Exponential moving average
    ExponentialMovingAverage,
    /// Percentile (presumably 0-100 — TODO confirm intended range)
    Percentile(u8),
}
277
/// Noise types for robustness testing
#[derive(Debug, Clone, Copy)]
pub enum NoiseType {
    /// Additive Gaussian noise
    Gaussian,
    /// Additive uniform noise
    Uniform,
    /// Salt-and-pepper (impulse) noise
    SaltPepper,
    /// Dropout-style masking noise
    Dropout,
}
286
/// Attack types for adversarial testing
#[derive(Debug, Clone, Copy)]
pub enum AttackType {
    /// Fast Gradient Sign Method
    FGSM,
    /// Projected Gradient Descent
    PGD,
    /// Carlini-Wagner attack
    CarliniWagner,
    /// DeepFool attack
    DeepFool,
}
295
/// Distribution shift types
#[derive(Debug, Clone, Copy)]
pub enum DistributionShiftType {
    /// Input distribution changes, conditional target distribution does not
    CovariateShift,
    /// Target concept changes over time
    ConceptDrift,
    /// Wholesale change of dataset
    DatasetShift,
    /// Shift driven by the passage of time
    TemporalShift,
}
304
305impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> TransformerEvaluator<T> {
306    /// Create new transformer evaluator
307    pub fn new(strategy: EvaluationStrategy) -> Result<Self> {
308        let mut metric_calculators = HashMap::new();
309
310        // Initialize standard metric calculators
311        metric_calculators.insert(
312            "convergence_speed".to_string(),
313            MetricCalculator::new("convergence_speed".to_string(), AggregationMethod::Mean)?,
314        );
315        metric_calculators.insert(
316            "final_performance".to_string(),
317            MetricCalculator::new("final_performance".to_string(), AggregationMethod::Mean)?,
318        );
319        metric_calculators.insert(
320            "sample_efficiency".to_string(),
321            MetricCalculator::new("sample_efficiency".to_string(), AggregationMethod::Mean)?,
322        );
323
324        Ok(Self {
325            strategy,
326            eval_params: EvaluationParams::default(),
327            metric_calculators,
328            performance_history: VecDeque::new(),
329            baseline_comparisons: HashMap::new(),
330            statistical_analyzers: Vec::new(),
331        })
332    }
333
334    /// Evaluate transformer optimizer performance
335    pub fn evaluate(
336        &mut self,
337        task_id: &str,
338        loss_trajectory: &[T],
339        gradient_norms: &[T],
340        wall_time: T,
341        memory_usage: u64,
342    ) -> Result<EvaluationResult<T>> {
343        let eval_id = format!("eval_{}_{}", task_id, self.performance_history.len());
344
345        // Compute convergence information
346        let convergence_info = self.compute_convergence_info(loss_trajectory, gradient_norms)?;
347
348        // Compute efficiency metrics
349        let efficiency_metrics =
350            self.compute_efficiency_metrics(wall_time, memory_usage, loss_trajectory.len())?;
351
352        // Compute performance metrics
353        let mut metrics = HashMap::new();
354        metrics.insert("final_loss".to_string(), convergence_info.final_loss);
355        metrics.insert(
356            "convergence_rate".to_string(),
357            convergence_info.convergence_rate,
358        );
359        metrics.insert(
360            "sample_efficiency".to_string(),
361            efficiency_metrics.sample_efficiency,
362        );
363
364        // Update metric calculators
365        for (metric_name, metric_value) in &metrics {
366            if let Some(calculator) = self.metric_calculators.get_mut(metric_name) {
367                calculator.update(*metric_value)?;
368            }
369        }
370
371        // Compute statistical significance if baseline exists
372        let statistical_significance = self.compute_statistical_significance(&metrics)?;
373
374        let result = EvaluationResult {
375            eval_id,
376            task_id: task_id.to_string(),
377            metrics,
378            convergence_info,
379            efficiency_metrics,
380            statistical_significance,
381            timestamp: self.performance_history.len(),
382        };
383
384        self.performance_history.push_back(result.clone());
385        if self.performance_history.len() > 1000 {
386            self.performance_history.pop_front();
387        }
388
389        Ok(result)
390    }
391
392    /// Compute convergence information from trajectories
393    fn compute_convergence_info(
394        &self,
395        loss_trajectory: &[T],
396        gradient_norms: &[T],
397    ) -> Result<ConvergenceInfo<T>> {
398        if loss_trajectory.is_empty() {
399            return Err(OptimError::InvalidConfig(
400                "Empty loss trajectory".to_string(),
401            ));
402        }
403
404        let final_loss = *loss_trajectory.last().unwrap();
405        let initial_loss = loss_trajectory[0];
406
407        // Detect convergence
408        let (converged, steps_to_convergence) = self.detect_convergence(loss_trajectory)?;
409
410        // Compute convergence rate
411        let convergence_rate = if loss_trajectory.len() > 1 {
412            let improvement = (initial_loss - final_loss)
413                / initial_loss
414                    .max(scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()));
415            improvement / T::from(loss_trajectory.len() as f64).unwrap()
416        } else {
417            T::zero()
418        };
419
420        Ok(ConvergenceInfo {
421            converged,
422            steps_to_convergence,
423            final_loss,
424            convergence_rate,
425            loss_trajectory: loss_trajectory.to_vec(),
426            gradient_norms: gradient_norms.to_vec(),
427        })
428    }
429
430    /// Detect convergence from loss trajectory
431    fn detect_convergence(&self, loss_trajectory: &[T]) -> Result<(bool, Option<usize>)> {
432        let window_size = 10.min(loss_trajectory.len());
433        let tolerance = self.eval_params.convergence_tolerance;
434
435        if loss_trajectory.len() < window_size {
436            return Ok((false, None));
437        }
438
439        // Check for convergence: small changes in moving average
440        for i in window_size..loss_trajectory.len() {
441            let current_window = &loss_trajectory[i - window_size..i];
442            let prev_window = &loss_trajectory[i - window_size - 1..i - 1];
443
444            let current_avg = current_window.iter().cloned().fold(T::zero(), |a, b| a + b)
445                / scirs2_core::numeric::NumCast::from(window_size as f64)
446                    .unwrap_or_else(|| T::zero());
447            let prev_avg = prev_window.iter().cloned().fold(T::zero(), |a, b| a + b)
448                / scirs2_core::numeric::NumCast::from(window_size as f64)
449                    .unwrap_or_else(|| T::zero());
450
451            let change = (current_avg - prev_avg).abs()
452                / prev_avg
453                    .max(scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()));
454
455            if change < tolerance {
456                return Ok((true, Some(i)));
457            }
458        }
459
460        Ok((false, None))
461    }
462
463    /// Compute efficiency metrics
464    fn compute_efficiency_metrics(
465        &self,
466        wall_time: T,
467        memory_usage: u64,
468        num_steps: usize,
469    ) -> Result<EfficiencyMetrics<T>> {
470        let flops = (num_steps as u64) * 1000; // Simplified FLOP estimation
471        let parameter_efficiency = T::one()
472            / (scirs2_core::numeric::NumCast::from(memory_usage as f64)
473                .unwrap_or_else(|| T::zero())
474                + T::one());
475        let sample_efficiency = scirs2_core::numeric::NumCast::from(num_steps as f64)
476            .unwrap_or_else(|| T::zero())
477            / (wall_time + T::one());
478        let energy_consumption = wall_time
479            * scirs2_core::numeric::NumCast::from(memory_usage as f64).unwrap_or_else(|| T::zero())
480            * scirs2_core::numeric::NumCast::from(1e-9).unwrap_or_else(|| T::zero());
481
482        Ok(EfficiencyMetrics {
483            wall_time,
484            flops,
485            peak_memory: memory_usage,
486            parameter_efficiency,
487            sample_efficiency,
488            energy_consumption,
489        })
490    }
491
492    /// Compute statistical significance
493    fn compute_statistical_significance(
494        &self,
495        metrics: &HashMap<String, T>,
496    ) -> Result<StatisticalSignificance<T>> {
497        // Simplified statistical significance computation
498        // In practice, this would involve proper statistical tests
499
500        let p_value = scirs2_core::numeric::NumCast::from(0.05).unwrap_or_else(|| T::zero()); // Placeholder
501        let effect_size = scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero()); // Cohen's d
502        let confidence_interval = (
503            scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
504            scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()),
505        );
506        let statistical_power =
507            scirs2_core::numeric::NumCast::from(0.8).unwrap_or_else(|| T::zero());
508        let test_statistic = scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero());
509
510        Ok(StatisticalSignificance {
511            p_value,
512            effect_size,
513            confidence_interval,
514            statistical_power,
515            test_statistic,
516        })
517    }
518
519    /// Add baseline for comparison
520    pub fn add_baseline(
521        &mut self,
522        baseline_name: String,
523        baseline_performance: HashMap<String, T>,
524    ) -> Result<()> {
525        let comparison = BaselineComparison {
526            baseline_name: baseline_name.clone(),
527            baseline_performance: baseline_performance.clone(),
528            improvement: HashMap::new(),
529            relative_performance: HashMap::new(),
530            win_rate: T::zero(),
531        };
532
533        self.baseline_comparisons.insert(baseline_name, comparison);
534        Ok(())
535    }
536
537    /// Run robustness evaluation
538    pub fn evaluate_robustness(
539        &mut self,
540        task_id: &str,
541        robustness_tests: &RobustnessTestSuite<T>,
542    ) -> Result<HashMap<String, T>> {
543        let mut robustness_scores = HashMap::new();
544
545        // Evaluate noise robustness
546        let mut noise_score = T::zero();
547        for noise_test in &robustness_tests.noise_tests {
548            noise_score = noise_score + (T::one() - noise_test.performance_degradation);
549        }
550        if !robustness_tests.noise_tests.is_empty() {
551            noise_score = noise_score / T::from(robustness_tests.noise_tests.len() as f64).unwrap();
552        }
553        robustness_scores.insert("noise_robustness".to_string(), noise_score);
554
555        // Evaluate adversarial robustness
556        let mut adversarial_score = T::zero();
557        for adv_test in &robustness_tests.adversarial_tests {
558            adversarial_score = adversarial_score + adv_test.robustness_score;
559        }
560        if !robustness_tests.adversarial_tests.is_empty() {
561            adversarial_score = adversarial_score
562                / T::from(robustness_tests.adversarial_tests.len() as f64).unwrap();
563        }
564        robustness_scores.insert("adversarial_robustness".to_string(), adversarial_score);
565
566        // Evaluate hyperparameter sensitivity
567        let mut sensitivity_score = T::zero();
568        for sens_test in &robustness_tests.sensitivity_tests {
569            sensitivity_score =
570                sensitivity_score + (T::one() / (T::one() + sens_test.sensitivity_score));
571        }
572        if !robustness_tests.sensitivity_tests.is_empty() {
573            sensitivity_score = sensitivity_score
574                / T::from(robustness_tests.sensitivity_tests.len() as f64).unwrap();
575        }
576        robustness_scores.insert("hyperparameter_robustness".to_string(), sensitivity_score);
577
578        Ok(robustness_scores)
579    }
580
581    /// Get comprehensive evaluation summary
582    pub fn get_evaluation_summary(&self) -> HashMap<String, T> {
583        let mut summary = HashMap::new();
584
585        // Overall performance statistics
586        if !self.performance_history.is_empty() {
587            // Average final performance
588            let avg_final_loss = self
589                .performance_history
590                .iter()
591                .map(|result| result.convergence_info.final_loss)
592                .fold(T::zero(), |a, b| a + b)
593                / T::from(self.performance_history.len() as f64).unwrap();
594            summary.insert("average_final_loss".to_string(), avg_final_loss);
595
596            // Average convergence rate
597            let avg_convergence_rate = self
598                .performance_history
599                .iter()
600                .map(|result| result.convergence_info.convergence_rate)
601                .fold(T::zero(), |a, b| a + b)
602                / T::from(self.performance_history.len() as f64).unwrap();
603            summary.insert("average_convergence_rate".to_string(), avg_convergence_rate);
604
605            // Success rate (convergence)
606            let success_count = self
607                .performance_history
608                .iter()
609                .filter(|result| result.convergence_info.converged)
610                .count();
611            let success_rate = scirs2_core::numeric::NumCast::from(success_count as f64)
612                .unwrap_or_else(|| T::zero())
613                / T::from(self.performance_history.len() as f64).unwrap();
614            summary.insert("success_rate".to_string(), success_rate);
615        }
616
617        summary.insert(
618            "total_evaluations".to_string(),
619            T::from(self.performance_history.len() as f64).unwrap(),
620        );
621        summary
622    }
623
624    /// Reset evaluator state
625    pub fn reset(&mut self) {
626        self.performance_history.clear();
627        self.baseline_comparisons.clear();
628        self.statistical_analyzers.clear();
629
630        for calculator in self.metric_calculators.values_mut() {
631            calculator.reset();
632        }
633    }
634
635    /// Update evaluation parameters
636    pub fn set_parameters(&mut self, params: EvaluationParams<T>) {
637        self.eval_params = params;
638    }
639}
640
641impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> MetricCalculator<T> {
642    fn new(metric_name: String, aggregation_method: AggregationMethod) -> Result<Self> {
643        Ok(Self {
644            metric_name,
645            calculation_params: HashMap::new(),
646            historical_values: VecDeque::new(),
647            aggregation_method,
648        })
649    }
650
651    fn update(&mut self, value: T) -> Result<()> {
652        self.historical_values.push_back(value);
653        if self.historical_values.len() > 1000 {
654            self.historical_values.pop_front();
655        }
656        Ok(())
657    }
658
659    fn get_aggregated_value(&self) -> Result<T> {
660        if self.historical_values.is_empty() {
661            return Ok(T::zero());
662        }
663
664        match self.aggregation_method {
665            AggregationMethod::Mean => {
666                let sum = self
667                    .historical_values
668                    .iter()
669                    .cloned()
670                    .fold(T::zero(), |a, b| a + b);
671                Ok(sum / T::from(self.historical_values.len() as f64).unwrap())
672            }
673            AggregationMethod::Max => Ok(self
674                .historical_values
675                .iter()
676                .cloned()
677                .fold(T::zero(), |a, b| a.max(b))),
678            AggregationMethod::Min => Ok(self.historical_values.iter().cloned().fold(
679                scirs2_core::numeric::NumCast::from(f64::INFINITY).unwrap_or_else(|| T::zero()),
680                |a, b| a.min(b),
681            )),
682            _ => Ok(self.historical_values.back().copied().unwrap_or(T::zero())),
683        }
684    }
685
686    fn reset(&mut self) {
687        self.historical_values.clear();
688    }
689}
690
691impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> Default for EvaluationParams<T> {
692    fn default() -> Self {
693        Self {
694            num_episodes: 10,
695            eval_frequency: 100,
696            convergence_tolerance: scirs2_core::numeric::NumCast::from(1e-6)
697                .unwrap_or_else(|| T::zero()),
698            max_eval_steps: 10000,
699            confidence_level: scirs2_core::numeric::NumCast::from(0.95)
700                .unwrap_or_else(|| T::zero()),
701            bootstrap_samples: 1000,
702            cv_folds: 5,
703            robustness_severity: scirs2_core::numeric::NumCast::from(0.1)
704                .unwrap_or_else(|| T::zero()),
705        }
706    }
707}