1use std::fmt::Debug;
2#[allow(dead_code)]
8use scirs2_core::ndarray::{Array1, Array2, Array3};
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11
12use crate::error::{OptimError, Result};
13
/// Evaluation protocol the evaluator should follow.
///
/// NOTE(review): the visible methods store but never branch on this value —
/// confirm where each strategy is actually consulted.
#[derive(Debug, Clone, Copy)]
pub enum EvaluationStrategy {
    /// Evaluate performance on a single task.
    SingleTask,
    /// Evaluate across multiple tasks.
    MultiTask,
    /// Evaluate transfer across problem domains.
    CrossDomain,
    /// Evaluate few-shot adaptation.
    FewShot,
    /// Evaluate continual/lifelong learning behavior.
    ContinualLearning,
    /// Evaluate robustness (noise, adversarial, sensitivity tests).
    Robustness,
    /// Evaluate computational/sample efficiency.
    Efficiency,
    /// Run all of the above.
    Comprehensive,
}
34
/// Evaluator for transformer-based optimizers: records per-run results,
/// maintains running metric calculators, and tracks baseline comparisons.
#[derive(Debug, Clone)]
pub struct TransformerEvaluator<T: Float + Debug + Send + Sync + 'static> {
    /// Selected evaluation protocol.
    strategy: EvaluationStrategy,

    /// Tunable evaluation parameters (tolerances, sample counts, ...).
    eval_params: EvaluationParams<T>,

    /// Running calculators keyed by metric name.
    metric_calculators: HashMap<String, MetricCalculator<T>>,

    /// Most recent evaluation results (bounded to 1000 entries by `evaluate`).
    performance_history: VecDeque<EvaluationResult<T>>,

    /// Registered baselines keyed by baseline name.
    baseline_comparisons: HashMap<String, BaselineComparison<T>>,

    /// Statistical analysis plugins.
    /// NOTE(review): not consulted by any visible method — confirm intent.
    statistical_analyzers: Vec<StatisticalAnalyzer<T>>,
}
56
/// Configuration knobs for an evaluation run (see `Default` impl for the
/// stock values).
#[derive(Debug, Clone)]
pub struct EvaluationParams<T: Float + Debug + Send + Sync + 'static> {
    /// Number of evaluation episodes to run.
    num_episodes: usize,

    /// Evaluate every this many training steps.
    eval_frequency: usize,

    /// Relative-change threshold below which a trajectory counts as converged.
    convergence_tolerance: T,

    /// Hard cap on evaluation steps.
    max_eval_steps: usize,

    /// Confidence level for interval estimates (e.g. 0.95).
    confidence_level: T,

    /// Number of bootstrap resamples for significance estimation.
    bootstrap_samples: usize,

    /// Number of cross-validation folds.
    cv_folds: usize,

    /// Severity level used by robustness tests.
    robustness_severity: T,
}
84
/// Outcome of a single evaluation pass over one optimization run.
#[derive(Debug, Clone)]
pub struct EvaluationResult<T: Float + Debug + Send + Sync + 'static> {
    /// Unique identifier for this evaluation (task id + history index).
    eval_id: String,

    /// Identifier of the evaluated task.
    task_id: String,

    /// Headline metric values keyed by metric name.
    metrics: HashMap<String, T>,

    /// Convergence analysis of the loss trajectory.
    convergence_info: ConvergenceInfo<T>,

    /// Compute/memory/sample efficiency measurements.
    efficiency_metrics: EfficiencyMetrics<T>,

    /// Statistical significance of the reported metrics.
    statistical_significance: StatisticalSignificance<T>,

    /// Logical timestamp (index into the evaluator's history at insertion).
    timestamp: usize,
}
109
/// Convergence analysis extracted from a loss trajectory.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo<T: Float + Debug + Send + Sync + 'static> {
    /// Whether the trajectory satisfied the convergence criterion.
    converged: bool,

    /// Step index at which convergence was first detected, if any.
    steps_to_convergence: Option<usize>,

    /// Last loss value in the trajectory.
    final_loss: T,

    /// Relative improvement from initial to final loss, per step.
    convergence_rate: T,

    /// Full copy of the per-step loss values.
    loss_trajectory: Vec<T>,

    /// Full copy of the per-step gradient norms.
    gradient_norms: Vec<T>,
}
131
/// Compute- and sample-efficiency measurements for a run.
#[derive(Debug, Clone)]
pub struct EfficiencyMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Elapsed wall-clock time for the run.
    wall_time: T,

    /// Estimated floating-point operations (placeholder: 1000 per step).
    flops: u64,

    /// Peak memory usage — presumably bytes; TODO confirm units at call site.
    peak_memory: u64,

    /// 1 / (memory + 1): higher means less memory per unit of work.
    parameter_efficiency: T,

    /// Steps completed per unit of wall time.
    sample_efficiency: T,

    /// Crude energy proxy derived from time x memory.
    energy_consumption: T,
}
153
/// Statistical-significance summary for a metric set.
///
/// NOTE(review): the only visible producer fills these with fixed placeholder
/// values — treat as provisional until a real test is wired in.
#[derive(Debug, Clone)]
pub struct StatisticalSignificance<T: Float + Debug + Send + Sync + 'static> {
    /// p-value of the significance test.
    p_value: T,

    /// Effect size (e.g. standardized mean difference).
    effect_size: T,

    /// (lower, upper) bounds of the confidence interval.
    confidence_interval: (T, T),

    /// Statistical power of the test.
    statistical_power: T,

    /// Raw test statistic.
    test_statistic: T,
}
172
/// Running aggregator for a single named metric.
#[derive(Debug, Clone)]
pub struct MetricCalculator<T: Float + Debug + Send + Sync + 'static> {
    /// Name of the metric this calculator tracks.
    metric_name: String,

    /// Optional per-metric parameters (unused by the visible methods).
    calculation_params: HashMap<String, T>,

    /// Recent observations (bounded to 1000 entries by `update`).
    historical_values: VecDeque<T>,

    /// How the history is reduced to a single value.
    aggregation_method: AggregationMethod,
}
188
/// Comparison of the evaluated optimizer against one named baseline.
#[derive(Debug, Clone)]
pub struct BaselineComparison<T: Float + Debug + Send + Sync + 'static> {
    /// Name of the baseline method.
    baseline_name: String,

    /// Baseline metric values keyed by metric name.
    baseline_performance: HashMap<String, T>,

    /// Absolute improvement over the baseline per metric
    /// (starts empty; filled by later comparison passes).
    improvement: HashMap<String, T>,

    /// Relative (ratio) performance versus the baseline per metric.
    relative_performance: HashMap<String, T>,

    /// Fraction of comparisons won against this baseline.
    win_rate: T,
}
207
/// Pluggable statistical-analysis component.
///
/// NOTE(review): declared and stored but never invoked in the visible code.
#[derive(Debug, Clone)]
pub struct StatisticalAnalyzer<T: Float + Debug + Send + Sync + 'static> {
    /// Human-readable analyzer name.
    analyzer_name: String,

    /// Analyzer configuration values.
    analysis_params: HashMap<String, T>,

    /// Cache of previously computed results keyed by analysis name.
    results_cache: HashMap<String, T>,
}
220
/// Collection of robustness tests consumed by
/// `TransformerEvaluator::evaluate_robustness`.
#[derive(Debug, Clone)]
pub struct RobustnessTestSuite<T: Float + Debug + Send + Sync + 'static> {
    /// Tests measuring degradation under input noise.
    noise_tests: Vec<NoiseTest<T>>,

    /// Tests measuring resistance to adversarial perturbations.
    adversarial_tests: Vec<AdversarialTest<T>>,

    /// Tests measuring sensitivity to hyperparameter changes.
    sensitivity_tests: Vec<SensitivityTest<T>>,

    /// Tests measuring adaptation under distribution shift.
    /// NOTE(review): not aggregated by `evaluate_robustness` — confirm intent.
    distribution_tests: Vec<DistributionTest<T>>,
}
236
/// Result of a single noise-robustness test.
#[derive(Debug, Clone)]
pub struct NoiseTest<T: Float + Debug + Send + Sync + 'static> {
    /// Kind of noise injected.
    noise_type: NoiseType,
    /// Magnitude of the injected noise.
    noise_level: T,
    /// Observed performance loss (0 = no loss; 1 - this is the retained score).
    performance_degradation: T,
}
244
/// Result of a single adversarial-robustness test.
#[derive(Debug, Clone)]
pub struct AdversarialTest<T: Float + Debug + Send + Sync + 'static> {
    /// Attack method used.
    attack_type: AttackType,
    /// Perturbation budget/strength of the attack.
    attack_strength: T,
    /// Robustness score under this attack (higher is better).
    robustness_score: T,
}
251
/// Result of a single hyperparameter-sensitivity test.
#[derive(Debug, Clone)]
pub struct SensitivityTest<T: Float + Debug + Send + Sync + 'static> {
    /// Name of the perturbed hyperparameter.
    parameter_name: String,
    /// (min, max) range over which the parameter was swept.
    parameter_range: (T, T),
    /// Sensitivity measure (higher = more sensitive; scored as 1/(1+s)).
    sensitivity_score: T,
}
258
/// Result of a single distribution-shift test.
#[derive(Debug, Clone)]
pub struct DistributionTest<T: Float + Debug + Send + Sync + 'static> {
    /// Kind of distribution shift applied.
    shift_type: DistributionShiftType,
    /// Magnitude of the shift.
    shift_magnitude: T,
    /// How well the optimizer adapted (higher is better).
    adaptation_score: T,
}
265
/// Reduction applied to a metric's history in
/// `MetricCalculator::get_aggregated_value`.
///
/// NOTE(review): only `Mean`, `Max`, and `Min` have dedicated
/// implementations; the rest currently fall back to the latest value.
#[derive(Debug, Clone, Copy)]
pub enum AggregationMethod {
    Mean,
    Median,
    Max,
    Min,
    WeightedAverage,
    ExponentialMovingAverage,
    /// Percentile rank — presumably in 0..=100; TODO confirm.
    Percentile(u8),
}
277
/// Kinds of noise injected by `NoiseTest`.
#[derive(Debug, Clone, Copy)]
pub enum NoiseType {
    /// Additive Gaussian noise.
    Gaussian,
    /// Additive uniform noise.
    Uniform,
    /// Salt-and-pepper (impulse) noise.
    SaltPepper,
    /// Random element dropout.
    Dropout,
}
286
/// Adversarial attack methods used by `AdversarialTest`.
#[derive(Debug, Clone, Copy)]
pub enum AttackType {
    /// Fast Gradient Sign Method.
    FGSM,
    /// Projected Gradient Descent.
    PGD,
    /// Carlini & Wagner attack.
    CarliniWagner,
    /// DeepFool attack.
    DeepFool,
}
295
/// Kinds of distribution shift exercised by `DistributionTest`.
#[derive(Debug, Clone, Copy)]
pub enum DistributionShiftType {
    /// Input distribution changes; conditional label distribution fixed.
    CovariateShift,
    /// The target concept changes over time.
    ConceptDrift,
    /// Wholesale change of dataset.
    DatasetShift,
    /// Gradual change over time.
    TemporalShift,
}
304
305impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> TransformerEvaluator<T> {
306 pub fn new(strategy: EvaluationStrategy) -> Result<Self> {
308 let mut metric_calculators = HashMap::new();
309
310 metric_calculators.insert(
312 "convergence_speed".to_string(),
313 MetricCalculator::new("convergence_speed".to_string(), AggregationMethod::Mean)?,
314 );
315 metric_calculators.insert(
316 "final_performance".to_string(),
317 MetricCalculator::new("final_performance".to_string(), AggregationMethod::Mean)?,
318 );
319 metric_calculators.insert(
320 "sample_efficiency".to_string(),
321 MetricCalculator::new("sample_efficiency".to_string(), AggregationMethod::Mean)?,
322 );
323
324 Ok(Self {
325 strategy,
326 eval_params: EvaluationParams::default(),
327 metric_calculators,
328 performance_history: VecDeque::new(),
329 baseline_comparisons: HashMap::new(),
330 statistical_analyzers: Vec::new(),
331 })
332 }
333
334 pub fn evaluate(
336 &mut self,
337 task_id: &str,
338 loss_trajectory: &[T],
339 gradient_norms: &[T],
340 wall_time: T,
341 memory_usage: u64,
342 ) -> Result<EvaluationResult<T>> {
343 let eval_id = format!("eval_{}_{}", task_id, self.performance_history.len());
344
345 let convergence_info = self.compute_convergence_info(loss_trajectory, gradient_norms)?;
347
348 let efficiency_metrics =
350 self.compute_efficiency_metrics(wall_time, memory_usage, loss_trajectory.len())?;
351
352 let mut metrics = HashMap::new();
354 metrics.insert("final_loss".to_string(), convergence_info.final_loss);
355 metrics.insert(
356 "convergence_rate".to_string(),
357 convergence_info.convergence_rate,
358 );
359 metrics.insert(
360 "sample_efficiency".to_string(),
361 efficiency_metrics.sample_efficiency,
362 );
363
364 for (metric_name, metric_value) in &metrics {
366 if let Some(calculator) = self.metric_calculators.get_mut(metric_name) {
367 calculator.update(*metric_value)?;
368 }
369 }
370
371 let statistical_significance = self.compute_statistical_significance(&metrics)?;
373
374 let result = EvaluationResult {
375 eval_id,
376 task_id: task_id.to_string(),
377 metrics,
378 convergence_info,
379 efficiency_metrics,
380 statistical_significance,
381 timestamp: self.performance_history.len(),
382 };
383
384 self.performance_history.push_back(result.clone());
385 if self.performance_history.len() > 1000 {
386 self.performance_history.pop_front();
387 }
388
389 Ok(result)
390 }
391
392 fn compute_convergence_info(
394 &self,
395 loss_trajectory: &[T],
396 gradient_norms: &[T],
397 ) -> Result<ConvergenceInfo<T>> {
398 if loss_trajectory.is_empty() {
399 return Err(OptimError::InvalidConfig(
400 "Empty loss trajectory".to_string(),
401 ));
402 }
403
404 let final_loss = *loss_trajectory.last().unwrap();
405 let initial_loss = loss_trajectory[0];
406
407 let (converged, steps_to_convergence) = self.detect_convergence(loss_trajectory)?;
409
410 let convergence_rate = if loss_trajectory.len() > 1 {
412 let improvement = (initial_loss - final_loss)
413 / initial_loss
414 .max(scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()));
415 improvement / T::from(loss_trajectory.len() as f64).unwrap()
416 } else {
417 T::zero()
418 };
419
420 Ok(ConvergenceInfo {
421 converged,
422 steps_to_convergence,
423 final_loss,
424 convergence_rate,
425 loss_trajectory: loss_trajectory.to_vec(),
426 gradient_norms: gradient_norms.to_vec(),
427 })
428 }
429
430 fn detect_convergence(&self, loss_trajectory: &[T]) -> Result<(bool, Option<usize>)> {
432 let window_size = 10.min(loss_trajectory.len());
433 let tolerance = self.eval_params.convergence_tolerance;
434
435 if loss_trajectory.len() < window_size {
436 return Ok((false, None));
437 }
438
439 for i in window_size..loss_trajectory.len() {
441 let current_window = &loss_trajectory[i - window_size..i];
442 let prev_window = &loss_trajectory[i - window_size - 1..i - 1];
443
444 let current_avg = current_window.iter().cloned().fold(T::zero(), |a, b| a + b)
445 / scirs2_core::numeric::NumCast::from(window_size as f64)
446 .unwrap_or_else(|| T::zero());
447 let prev_avg = prev_window.iter().cloned().fold(T::zero(), |a, b| a + b)
448 / scirs2_core::numeric::NumCast::from(window_size as f64)
449 .unwrap_or_else(|| T::zero());
450
451 let change = (current_avg - prev_avg).abs()
452 / prev_avg
453 .max(scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()));
454
455 if change < tolerance {
456 return Ok((true, Some(i)));
457 }
458 }
459
460 Ok((false, None))
461 }
462
463 fn compute_efficiency_metrics(
465 &self,
466 wall_time: T,
467 memory_usage: u64,
468 num_steps: usize,
469 ) -> Result<EfficiencyMetrics<T>> {
470 let flops = (num_steps as u64) * 1000; let parameter_efficiency = T::one()
472 / (scirs2_core::numeric::NumCast::from(memory_usage as f64)
473 .unwrap_or_else(|| T::zero())
474 + T::one());
475 let sample_efficiency = scirs2_core::numeric::NumCast::from(num_steps as f64)
476 .unwrap_or_else(|| T::zero())
477 / (wall_time + T::one());
478 let energy_consumption = wall_time
479 * scirs2_core::numeric::NumCast::from(memory_usage as f64).unwrap_or_else(|| T::zero())
480 * scirs2_core::numeric::NumCast::from(1e-9).unwrap_or_else(|| T::zero());
481
482 Ok(EfficiencyMetrics {
483 wall_time,
484 flops,
485 peak_memory: memory_usage,
486 parameter_efficiency,
487 sample_efficiency,
488 energy_consumption,
489 })
490 }
491
492 fn compute_statistical_significance(
494 &self,
495 metrics: &HashMap<String, T>,
496 ) -> Result<StatisticalSignificance<T>> {
497 let p_value = scirs2_core::numeric::NumCast::from(0.05).unwrap_or_else(|| T::zero()); let effect_size = scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero()); let confidence_interval = (
503 scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
504 scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()),
505 );
506 let statistical_power =
507 scirs2_core::numeric::NumCast::from(0.8).unwrap_or_else(|| T::zero());
508 let test_statistic = scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero());
509
510 Ok(StatisticalSignificance {
511 p_value,
512 effect_size,
513 confidence_interval,
514 statistical_power,
515 test_statistic,
516 })
517 }
518
519 pub fn add_baseline(
521 &mut self,
522 baseline_name: String,
523 baseline_performance: HashMap<String, T>,
524 ) -> Result<()> {
525 let comparison = BaselineComparison {
526 baseline_name: baseline_name.clone(),
527 baseline_performance: baseline_performance.clone(),
528 improvement: HashMap::new(),
529 relative_performance: HashMap::new(),
530 win_rate: T::zero(),
531 };
532
533 self.baseline_comparisons.insert(baseline_name, comparison);
534 Ok(())
535 }
536
537 pub fn evaluate_robustness(
539 &mut self,
540 task_id: &str,
541 robustness_tests: &RobustnessTestSuite<T>,
542 ) -> Result<HashMap<String, T>> {
543 let mut robustness_scores = HashMap::new();
544
545 let mut noise_score = T::zero();
547 for noise_test in &robustness_tests.noise_tests {
548 noise_score = noise_score + (T::one() - noise_test.performance_degradation);
549 }
550 if !robustness_tests.noise_tests.is_empty() {
551 noise_score = noise_score / T::from(robustness_tests.noise_tests.len() as f64).unwrap();
552 }
553 robustness_scores.insert("noise_robustness".to_string(), noise_score);
554
555 let mut adversarial_score = T::zero();
557 for adv_test in &robustness_tests.adversarial_tests {
558 adversarial_score = adversarial_score + adv_test.robustness_score;
559 }
560 if !robustness_tests.adversarial_tests.is_empty() {
561 adversarial_score = adversarial_score
562 / T::from(robustness_tests.adversarial_tests.len() as f64).unwrap();
563 }
564 robustness_scores.insert("adversarial_robustness".to_string(), adversarial_score);
565
566 let mut sensitivity_score = T::zero();
568 for sens_test in &robustness_tests.sensitivity_tests {
569 sensitivity_score =
570 sensitivity_score + (T::one() / (T::one() + sens_test.sensitivity_score));
571 }
572 if !robustness_tests.sensitivity_tests.is_empty() {
573 sensitivity_score = sensitivity_score
574 / T::from(robustness_tests.sensitivity_tests.len() as f64).unwrap();
575 }
576 robustness_scores.insert("hyperparameter_robustness".to_string(), sensitivity_score);
577
578 Ok(robustness_scores)
579 }
580
581 pub fn get_evaluation_summary(&self) -> HashMap<String, T> {
583 let mut summary = HashMap::new();
584
585 if !self.performance_history.is_empty() {
587 let avg_final_loss = self
589 .performance_history
590 .iter()
591 .map(|result| result.convergence_info.final_loss)
592 .fold(T::zero(), |a, b| a + b)
593 / T::from(self.performance_history.len() as f64).unwrap();
594 summary.insert("average_final_loss".to_string(), avg_final_loss);
595
596 let avg_convergence_rate = self
598 .performance_history
599 .iter()
600 .map(|result| result.convergence_info.convergence_rate)
601 .fold(T::zero(), |a, b| a + b)
602 / T::from(self.performance_history.len() as f64).unwrap();
603 summary.insert("average_convergence_rate".to_string(), avg_convergence_rate);
604
605 let success_count = self
607 .performance_history
608 .iter()
609 .filter(|result| result.convergence_info.converged)
610 .count();
611 let success_rate = scirs2_core::numeric::NumCast::from(success_count as f64)
612 .unwrap_or_else(|| T::zero())
613 / T::from(self.performance_history.len() as f64).unwrap();
614 summary.insert("success_rate".to_string(), success_rate);
615 }
616
617 summary.insert(
618 "total_evaluations".to_string(),
619 T::from(self.performance_history.len() as f64).unwrap(),
620 );
621 summary
622 }
623
624 pub fn reset(&mut self) {
626 self.performance_history.clear();
627 self.baseline_comparisons.clear();
628 self.statistical_analyzers.clear();
629
630 for calculator in self.metric_calculators.values_mut() {
631 calculator.reset();
632 }
633 }
634
635 pub fn set_parameters(&mut self, params: EvaluationParams<T>) {
637 self.eval_params = params;
638 }
639}
640
641impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> MetricCalculator<T> {
642 fn new(metric_name: String, aggregation_method: AggregationMethod) -> Result<Self> {
643 Ok(Self {
644 metric_name,
645 calculation_params: HashMap::new(),
646 historical_values: VecDeque::new(),
647 aggregation_method,
648 })
649 }
650
651 fn update(&mut self, value: T) -> Result<()> {
652 self.historical_values.push_back(value);
653 if self.historical_values.len() > 1000 {
654 self.historical_values.pop_front();
655 }
656 Ok(())
657 }
658
659 fn get_aggregated_value(&self) -> Result<T> {
660 if self.historical_values.is_empty() {
661 return Ok(T::zero());
662 }
663
664 match self.aggregation_method {
665 AggregationMethod::Mean => {
666 let sum = self
667 .historical_values
668 .iter()
669 .cloned()
670 .fold(T::zero(), |a, b| a + b);
671 Ok(sum / T::from(self.historical_values.len() as f64).unwrap())
672 }
673 AggregationMethod::Max => Ok(self
674 .historical_values
675 .iter()
676 .cloned()
677 .fold(T::zero(), |a, b| a.max(b))),
678 AggregationMethod::Min => Ok(self.historical_values.iter().cloned().fold(
679 scirs2_core::numeric::NumCast::from(f64::INFINITY).unwrap_or_else(|| T::zero()),
680 |a, b| a.min(b),
681 )),
682 _ => Ok(self.historical_values.back().copied().unwrap_or(T::zero())),
683 }
684 }
685
686 fn reset(&mut self) {
687 self.historical_values.clear();
688 }
689}
690
691impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> Default for EvaluationParams<T> {
692 fn default() -> Self {
693 Self {
694 num_episodes: 10,
695 eval_frequency: 100,
696 convergence_tolerance: scirs2_core::numeric::NumCast::from(1e-6)
697 .unwrap_or_else(|| T::zero()),
698 max_eval_steps: 10000,
699 confidence_level: scirs2_core::numeric::NumCast::from(0.95)
700 .unwrap_or_else(|| T::zero()),
701 bootstrap_samples: 1000,
702 cv_folds: 5,
703 robustness_severity: scirs2_core::numeric::NumCast::from(0.1)
704 .unwrap_or_else(|| T::zero()),
705 }
706 }
707}