1use optirs_core::error::{OptimError, Result};
8use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
9use scirs2_core::numeric::Float;
10use std::collections::VecDeque;
11use std::fmt::Debug;
12
/// Boxed closure evaluating the scalar objective at a parameter vector.
pub type ObjectiveFunction<A> = Box<dyn Fn(&Array1<A>) -> A>;
/// Boxed closure returning the gradient of the objective at a parameter vector.
pub type GradientFunction<A> = Box<dyn Fn(&Array1<A>) -> Array1<A>>;
17
/// Tracks per-step gradient statistics (magnitudes, direction similarity,
/// parameter updates) over a sliding window for convergence analysis.
#[derive(Debug)]
pub struct GradientFlowAnalyzer<A: Float, D: Dimension> {
    // Per-step L2 norms, one entry per parameter group.
    gradient_magnitudes: VecDeque<Vec<A>>,
    // Cosine similarity between consecutive steps (see `record_step`).
    gradient_directions: VecDeque<A>,
    // History of raw parameter updates, one Vec per recorded step.
    parameter_updates: VecDeque<Vec<Array<A, D>>>,
    // Total number of recorded steps (not capped by the window).
    step_count: usize,
    // Maximum number of entries retained in each history deque.
    _maxhistory: usize,
    // Lazily computed stats; invalidated on every `record_step`.
    stats_cache: Option<GradientFlowStats<A>>,
    cache_valid: bool,
}
36
37impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientFlowAnalyzer<A, D> {
38 pub fn new(_maxhistory: usize) -> Self {
40 Self {
41 gradient_magnitudes: VecDeque::with_capacity(_maxhistory),
42 gradient_directions: VecDeque::with_capacity(_maxhistory),
43 parameter_updates: VecDeque::with_capacity(_maxhistory),
44 step_count: 0,
45 _maxhistory,
46 stats_cache: None,
47 cache_valid: false,
48 }
49 }
50
51 pub fn record_step(
53 &mut self,
54 gradients: &[Array<A, D>],
55 parameter_updates: &[Array<A, D>],
56 ) -> Result<()> {
57 if gradients.len() != parameter_updates.len() {
58 return Err(OptimError::DimensionMismatch(
59 "Number of gradients must match number of parameter _updates".to_string(),
60 ));
61 }
62
63 self.step_count += 1;
64
65 let magnitudes: Vec<A> = gradients
67 .iter()
68 .map(|grad| grad.mapv(|x| x * x).sum().sqrt())
69 .collect();
70
71 self.gradient_magnitudes.push_back(magnitudes);
72
73 if let Some(prev_gradients) = self.parameter_updates.back() {
75 let similarity = self.calculate_cosine_similarity(gradients, prev_gradients)?;
76 self.gradient_directions.push_back(similarity);
77 } else {
78 self.gradient_directions.push_back(A::one());
80 }
81
82 self.parameter_updates.push_back(parameter_updates.to_vec());
84
85 if self.gradient_magnitudes.len() > self._maxhistory {
87 self.gradient_magnitudes.pop_front();
88 }
89 if self.gradient_directions.len() > self._maxhistory {
90 self.gradient_directions.pop_front();
91 }
92 if self.parameter_updates.len() > self._maxhistory {
93 self.parameter_updates.pop_front();
94 }
95
96 self.cache_valid = false;
98
99 Ok(())
100 }
101
102 fn calculate_cosine_similarity(
104 &self,
105 arrays1: &[Array<A, D>],
106 arrays2: &[Array<A, D>],
107 ) -> Result<A> {
108 if arrays1.len() != arrays2.len() {
109 return Err(OptimError::DimensionMismatch(
110 "Array sets must have same length".to_string(),
111 ));
112 }
113
114 let mut dot_product = A::zero();
115 let mut norm1_sq = A::zero();
116 let mut norm2_sq = A::zero();
117
118 for (arr1, arr2) in arrays1.iter().zip(arrays2.iter()) {
119 for (&a, &b) in arr1.iter().zip(arr2.iter()) {
120 dot_product = dot_product + a * b;
121 norm1_sq = norm1_sq + a * a;
122 norm2_sq = norm2_sq + b * b;
123 }
124 }
125
126 let norm1 = norm1_sq.sqrt();
127 let norm2 = norm2_sq.sqrt();
128
129 if norm1 > A::zero() && norm2 > A::zero() {
130 Ok(dot_product / (norm1 * norm2))
131 } else {
132 Ok(A::zero())
133 }
134 }
135
136 pub fn get_stats(&mut self) -> &GradientFlowStats<A> {
138 if !self.cache_valid {
139 self.stats_cache = Some(self.compute_stats());
140 self.cache_valid = true;
141 }
142 self.stats_cache.as_ref().expect("unwrap failed")
143 }
144
    /// Recomputes `GradientFlowStats` from the current histories.
    fn compute_stats(&self) -> GradientFlowStats<A> {
        // Number of parameter groups, taken from the first recorded step
        // (assumes every step recorded the same number of groups).
        let num_param_groups = if let Some(first) = self.gradient_magnitudes.front() {
            first.len()
        } else {
            0
        };

        let mut per_group_stats = Vec::new();
        for group_idx in 0..num_param_groups {
            // Magnitude time series for this group across the window.
            let group_magnitudes: Vec<A> = self
                .gradient_magnitudes
                .iter()
                .map(|step_mags| step_mags[group_idx])
                .collect();

            let mean_magnitude = if !group_magnitudes.is_empty() {
                group_magnitudes.iter().fold(A::zero(), |acc, &x| acc + x)
                    / A::from(group_magnitudes.len()).expect("unwrap failed")
            } else {
                A::zero()
            };

            // Sample variance (n - 1 denominator); zero for a single sample.
            let variance = if group_magnitudes.len() > 1 {
                let mean = mean_magnitude;
                let sum_sq_diff = group_magnitudes
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .fold(A::zero(), |acc, x| acc + x);
                sum_sq_diff / A::from(group_magnitudes.len() - 1).expect("unwrap failed")
            } else {
                A::zero()
            };

            // Extremes over the window; +/- infinity are the fold identities.
            let max_magnitude = group_magnitudes
                .iter()
                .fold(A::neg_infinity(), |acc, &x| acc.max(x));

            let min_magnitude = group_magnitudes
                .iter()
                .fold(A::infinity(), |acc, &x| acc.min(x));

            per_group_stats.push(ParameterGroupStats {
                mean_magnitude,
                variance,
                std_dev: variance.sqrt(),
                max_magnitude,
                min_magnitude,
                magnitude_history: group_magnitudes,
            });
        }

        // Mean cosine similarity; defaults to 1 (fully aligned) with no data.
        let mean_direction_similarity = if !self.gradient_directions.is_empty() {
            self.gradient_directions
                .iter()
                .fold(A::zero(), |acc, &x| acc + x)
                / A::from(self.gradient_directions.len()).expect("unwrap failed")
        } else {
            A::one()
        };

        // Sample variance of the direction similarities (n - 1 denominator).
        let direction_variance = if self.gradient_directions.len() > 1 {
            let mean = mean_direction_similarity;
            let sum_sq_diff = self
                .gradient_directions
                .iter()
                .map(|&x| (x - mean) * (x - mean))
                .fold(A::zero(), |acc, x| acc + x);
            sum_sq_diff / A::from(self.gradient_directions.len() - 1).expect("unwrap failed")
        } else {
            A::zero()
        };

        // Derived heuristics (see the respective helpers).
        let is_converging = self.analyze_convergence();
        let oscillation_frequency = self.calculate_oscillation_frequency();
        let stability_score = self.calculate_stability_score();

        GradientFlowStats {
            step_count: self.step_count,
            per_group_stats,
            mean_direction_similarity,
            direction_variance,
            direction_std_dev: direction_variance.sqrt(),
            is_converging,
            oscillation_frequency,
            stability_score,
            direction_history: self.gradient_directions.iter().copied().collect(),
        }
    }
237
    /// Heuristic convergence check: true when, over the last 5 recorded
    /// steps, at least half of the parameter groups show mostly-decreasing
    /// gradient magnitudes. Always false with fewer than 5 steps.
    fn analyze_convergence(&self) -> bool {
        if self.gradient_magnitudes.len() < 5 {
            return false;
        }

        let recent_steps = 5.min(self.gradient_magnitudes.len());
        // Newest-first view of the last `recent_steps` entries.
        let recent_magnitudes: Vec<_> = self
            .gradient_magnitudes
            .iter()
            .rev()
            .take(recent_steps)
            .collect();

        let mut converging_groups = 0;
        // Assumes every step has the same group count as the newest one.
        let num_groups = recent_magnitudes[0].len();

        for group_idx in 0..num_groups {
            // Re-reversed so the trend runs oldest -> newest.
            let group_trend: Vec<A> = recent_magnitudes
                .iter()
                .rev()
                .map(|step| step[group_idx])
                .collect();

            // "Decreasing" when at least half of consecutive pairs decrease.
            let is_decreasing = group_trend
                .windows(2)
                .map(|window| window[1] < window[0])
                .filter(|&x| x)
                .count()
                >= group_trend.len() / 2;

            if is_decreasing {
                converging_groups += 1;
            }
        }

        converging_groups >= num_groups / 2
    }
279
280 fn calculate_oscillation_frequency(&self) -> f64 {
282 if self.gradient_directions.len() < 3 {
283 return 0.0;
284 }
285
286 let mut sign_changes = 0;
287 let mut prev_positive = None;
288
289 for &direction in &self.gradient_directions {
290 let is_positive = direction >= A::zero();
291 if let Some(prev) = prev_positive {
292 if prev != is_positive {
293 sign_changes += 1;
294 }
295 }
296 prev_positive = Some(is_positive);
297 }
298
299 sign_changes as f64 / (self.gradient_directions.len() - 1) as f64
300 }
301
    /// Combined stability score: the average of (a) the mean absolute
    /// direction similarity and (b) an inverse coefficient-of-variation
    /// score over all gradient magnitudes. Returns 1.0 when no directions
    /// have been recorded.
    fn calculate_stability_score(&self) -> f64 {
        if self.gradient_directions.is_empty() {
            return 1.0;
        }

        // Mean |cosine similarity| over the window.
        let direction_consistency = self
            .gradient_directions
            .iter()
            .fold(A::zero(), |acc, &x| acc + x.abs())
            / A::from(self.gradient_directions.len()).expect("unwrap failed");

        let magnitude_consistency = if !self.gradient_magnitudes.is_empty() {
            // Pool all groups' magnitudes into a single sample.
            let all_magnitudes: Vec<A> = self
                .gradient_magnitudes
                .iter()
                .flat_map(|step| step.iter())
                .copied()
                .collect();

            if all_magnitudes.len() > 1 {
                let mean = all_magnitudes.iter().fold(A::zero(), |acc, &x| acc + x)
                    / A::from(all_magnitudes.len()).expect("unwrap failed");
                // Population variance (n denominator) — note this differs
                // from the sample variance used in `compute_stats`.
                let variance = all_magnitudes
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .fold(A::zero(), |acc, x| acc + x)
                    / A::from(all_magnitudes.len()).expect("unwrap failed");
                // Coefficient of variation; zero unless the mean is positive.
                let cv = if mean > A::zero() {
                    variance.sqrt() / mean
                } else {
                    A::zero()
                };
                // Maps cv in [0, inf) to a score in (0, 1].
                (A::one() / (A::one() + cv)).to_f64().unwrap_or(0.0)
            } else {
                1.0
            }
        } else {
            1.0
        };

        let direction_score = direction_consistency.to_f64().unwrap_or(0.0).abs();
        (direction_score + magnitude_consistency) / 2.0
    }
348
    /// Total number of steps recorded since creation (or the last `clear`).
    pub fn step_count(&self) -> usize {
        self.step_count
    }
353
354 pub fn clear(&mut self) {
356 self.gradient_magnitudes.clear();
357 self.gradient_directions.clear();
358 self.parameter_updates.clear();
359 self.step_count = 0;
360 self.cache_valid = false;
361 self.stats_cache = None;
362 }
363
364 pub fn export_for_visualization(&self) -> VisualizationData<A> {
366 let magnitude_series: Vec<Vec<A>> = if !self.gradient_magnitudes.is_empty() {
367 let num_groups = self.gradient_magnitudes[0].len();
368 (0..num_groups)
369 .map(|group_idx| {
370 self.gradient_magnitudes
371 .iter()
372 .map(|step| step[group_idx])
373 .collect()
374 })
375 .collect()
376 } else {
377 Vec::new()
378 };
379
380 VisualizationData {
381 step_indices: (0..self.step_count).collect(),
382 magnitude_series,
383 direction_similarities: self.gradient_directions.iter().copied().collect(),
384 }
385 }
386}
387
/// Aggregate statistics produced by `GradientFlowAnalyzer`.
#[derive(Debug, Clone)]
pub struct GradientFlowStats<A: Float> {
    /// Total steps recorded by the analyzer.
    pub step_count: usize,
    /// Per-parameter-group magnitude statistics.
    pub per_group_stats: Vec<ParameterGroupStats<A>>,
    /// Mean cosine similarity between consecutive steps.
    pub mean_direction_similarity: A,
    /// Sample variance of the direction similarities.
    pub direction_variance: A,
    /// Square root of `direction_variance`.
    pub direction_std_dev: A,
    /// Heuristic convergence flag (see `analyze_convergence`).
    pub is_converging: bool,
    /// Fraction of consecutive direction pairs with a sign flip.
    pub oscillation_frequency: f64,
    /// Combined stability score (see `calculate_stability_score`).
    pub stability_score: f64,
    /// Raw direction-similarity history within the window.
    pub direction_history: Vec<A>,
}
410
/// Gradient-magnitude statistics for a single parameter group.
#[derive(Debug, Clone)]
pub struct ParameterGroupStats<A: Float> {
    /// Mean L2 gradient magnitude over the window.
    pub mean_magnitude: A,
    /// Sample variance of the magnitudes.
    pub variance: A,
    /// Square root of `variance`.
    pub std_dev: A,
    /// Largest magnitude in the window.
    pub max_magnitude: A,
    /// Smallest magnitude in the window.
    pub min_magnitude: A,
    /// Raw magnitude history within the window.
    pub magnitude_history: Vec<A>,
}
427
/// Plot-ready series exported by `GradientFlowAnalyzer::export_for_visualization`.
#[derive(Debug, Clone)]
pub struct VisualizationData<A: Float> {
    /// Step numbers `0..step_count`.
    pub step_indices: Vec<usize>,
    /// One magnitude series per parameter group.
    pub magnitude_series: Vec<Vec<A>>,
    /// Direction-similarity series within the window.
    pub direction_similarities: Vec<A>,
}
438
/// Harness that runs optimizer step functions against a suite of test
/// functions and accumulates the results.
pub struct OptimizerBenchmark<A: Float> {
    // Registered benchmark problems.
    test_functions: Vec<TestFunction<A>>,
    // Results accumulated across `run_benchmark` calls.
    results: Vec<BenchmarkResult<A>>,
}
446
447impl<A: Float + ScalarOperand + Debug + Send + Sync> OptimizerBenchmark<A> {
448 pub fn new() -> Self {
450 Self {
451 test_functions: Vec::new(),
452 results: Vec::new(),
453 }
454 }
455
    /// Registers an additional test function for subsequent benchmark runs.
    pub fn add_test_function(&mut self, testfunction: TestFunction<A>) {
        self.test_functions.push(testfunction);
    }
460
    /// Registers three classic benchmarks with analytic gradients: a 10-D
    /// quadratic, the 2-D Rosenbrock function, and a 5-D sphere. All have
    /// known optima, so `final_error` is meaningful for them.
    pub fn add_standard_test_functions(&mut self) {
        // f(x) = sum(x_i^2), grad = 2x, minimum 0 at the origin.
        self.add_test_function(TestFunction {
            name: "Quadratic".to_string(),
            dimension: 10,
            function: Box::new(|x: &Array1<A>| x.mapv(|val| val * val).sum()),
            gradient: Box::new(|x: &Array1<A>| {
                x.mapv(|val| A::from(2.0).expect("unwrap failed") * val)
            }),
            optimal_value: Some(A::zero()),
            optimal_point: Some(Array1::zeros(10)),
        });

        // f(x, y) = (a - x)^2 + b (y - x^2)^2 with a = 1, b = 100;
        // minimum 0 at (1, 1).
        self.add_test_function(TestFunction {
            name: "Rosenbrock".to_string(),
            dimension: 2,
            function: Box::new(|x: &Array1<A>| {
                let a = A::one();
                let b = A::from(100.0).expect("unwrap failed");
                let term1 = (a - x[0]) * (a - x[0]);
                let term2 = b * (x[1] - x[0] * x[0]) * (x[1] - x[0] * x[0]);
                term1 + term2
            }),
            gradient: Box::new(|x: &Array1<A>| {
                let a = A::one();
                let b = A::from(100.0).expect("unwrap failed");
                // df/dx = -2(a - x) - 4 b x (y - x^2)
                let grad_x = A::from(-2.0).expect("unwrap failed") * (a - x[0])
                    - A::from(4.0).expect("unwrap failed") * b * x[0] * (x[1] - x[0] * x[0]);
                // df/dy = 2 b (y - x^2)
                let grad_y = A::from(2.0).expect("unwrap failed") * b * (x[1] - x[0] * x[0]);
                Array1::from_vec(vec![grad_x, grad_y])
            }),
            optimal_value: Some(A::zero()),
            optimal_point: Some(Array1::from_vec(vec![A::one(), A::one()])),
        });

        // Same form as "Quadratic" but 5-D; kept as a separate named case.
        self.add_test_function(TestFunction {
            name: "Sphere".to_string(),
            dimension: 5,
            function: Box::new(|x: &Array1<A>| x.mapv(|val| val * val).sum()),
            gradient: Box::new(|x: &Array1<A>| {
                x.mapv(|val| A::from(2.0).expect("unwrap failed") * val)
            }),
            optimal_value: Some(A::zero()),
            optimal_point: Some(Array1::zeros(5)),
        });
    }
510
511 pub fn run_benchmark<F>(
513 &mut self,
514 optimizername: String,
515 mut optimization_step: F,
516 max_iterations: usize,
517 tolerance: A,
518 ) -> Result<Vec<BenchmarkResult<A>>>
519 where
520 F: FnMut(&Array1<A>, &Array1<A>) -> Array1<A>,
521 {
522 let mut results = Vec::new();
523
524 for testfunction in &self.test_functions {
525 let mut x = Array1::from_vec(
526 (0..testfunction.dimension)
527 .map(|_| A::from(0.5).expect("unwrap failed"))
528 .collect(),
529 );
530
531 let mut function_values = Vec::new();
532 let mut gradient_norms = Vec::new();
533 let mut convergence_step = None;
534
535 let start_time = std::time::Instant::now();
536
537 for iteration in 0..max_iterations {
538 let f_val = (testfunction.function)(&x);
539 let grad = (testfunction.gradient)(&x);
540 let grad_norm = grad.mapv(|g| g * g).sum().sqrt();
541
542 function_values.push(f_val);
543 gradient_norms.push(grad_norm);
544
545 if grad_norm < tolerance {
547 convergence_step = Some(iteration);
548 break;
549 }
550
551 x = optimization_step(&x, &grad);
553 }
554
555 let elapsed = start_time.elapsed();
556
557 let final_error = if let Some(optimal_value) = testfunction.optimal_value {
558 (function_values.last().copied().expect("unwrap failed") - optimal_value).abs()
559 } else {
560 A::zero()
561 };
562
563 let result = BenchmarkResult {
564 optimizername: optimizername.clone(),
565 function_name: testfunction.name.clone(),
566 converged: convergence_step.is_some(),
567 convergence_step,
568 final_function_value: *function_values.last().expect("unwrap failed"),
569 final_gradient_norm: *gradient_norms.last().expect("unwrap failed"),
570 final_error,
571 iterations_taken: function_values.len(),
572 elapsed_time: elapsed,
573 function_evaluations: function_values.len(),
574 function_value_history: function_values,
575 gradient_norm_history: gradient_norms,
576 };
577
578 results.push(result.clone());
579 }
580
581 self.results.extend(results.clone());
582 Ok(results)
583 }
584
    /// All results accumulated so far, in insertion order.
    pub fn get_results(&self) -> &[BenchmarkResult<A>] {
        &self.results
    }
589
    /// Discards all accumulated benchmark results.
    pub fn clear_results(&mut self) {
        self.results.clear();
    }
594
    /// Aggregates all stored results into per-optimizer averages.
    pub fn generate_report(&self) -> BenchmarkReport<A> {
        let mut optimizer_performance = std::collections::HashMap::new();

        // First pass: accumulate running sums per optimizer name.
        for result in &self.results {
            let entry = optimizer_performance
                .entry(result.optimizername.clone())
                .or_insert_with(|| OptimizerPerformance {
                    total_runs: 0,
                    successful_runs: 0,
                    average_iterations: 0.0,
                    average_final_error: A::zero(),
                    average_time: std::time::Duration::from_secs(0),
                });

            entry.total_runs += 1;
            if result.converged {
                entry.successful_runs += 1;
            }
            // These fields hold sums here; they become means below.
            entry.average_iterations += result.iterations_taken as f64;
            entry.average_final_error = entry.average_final_error + result.final_error;
            entry.average_time += result.elapsed_time;
        }

        // Second pass: turn the accumulated sums into averages.
        for performance in optimizer_performance.values_mut() {
            if performance.total_runs > 0 {
                performance.average_iterations /= performance.total_runs as f64;
                performance.average_final_error = performance.average_final_error
                    / A::from(performance.total_runs).expect("unwrap failed");
                performance.average_time /= performance.total_runs as u32;
            }
        }

        BenchmarkReport {
            total_tests: self.results.len(),
            optimizer_performance,
        }
    }
634}
635
/// A benchmark problem: objective, analytic gradient, and (optionally)
/// its known optimum for error reporting.
pub struct TestFunction<A: Float> {
    /// Human-readable problem name.
    pub name: String,
    /// Dimensionality of the parameter vector.
    pub dimension: usize,
    /// The objective f(x).
    pub function: ObjectiveFunction<A>,
    /// The gradient of f at x.
    pub gradient: GradientFunction<A>,
    /// Known minimum value, if any.
    pub optimal_value: Option<A>,
    /// Known minimizer, if any.
    pub optimal_point: Option<Array1<A>>,
}
651
/// Outcome of one optimizer run on one test function.
#[derive(Debug, Clone)]
pub struct BenchmarkResult<A: Float> {
    /// Name of the optimizer under test.
    pub optimizername: String,
    /// Name of the test function.
    pub function_name: String,
    /// Whether the gradient tolerance was reached.
    pub converged: bool,
    /// Iteration index at which convergence occurred, if it did.
    pub convergence_step: Option<usize>,
    /// Objective value at the last recorded iteration.
    pub final_function_value: A,
    /// Gradient norm at the last recorded iteration.
    pub final_gradient_norm: A,
    /// |final value - known optimum|, or zero when no optimum is known.
    pub final_error: A,
    /// Number of recorded iterations.
    pub iterations_taken: usize,
    /// Wall-clock duration of the run.
    pub elapsed_time: std::time::Duration,
    /// Number of objective evaluations performed.
    pub function_evaluations: usize,
    /// Objective value per iteration.
    pub function_value_history: Vec<A>,
    /// Gradient norm per iteration.
    pub gradient_norm_history: Vec<A>,
}
680
/// Per-optimizer aggregates inside a `BenchmarkReport`.
#[derive(Debug, Clone)]
pub struct OptimizerPerformance<A: Float> {
    /// Number of benchmark runs recorded.
    pub total_runs: usize,
    /// Runs that reached the gradient tolerance.
    pub successful_runs: usize,
    /// Mean iterations per run.
    pub average_iterations: f64,
    /// Mean final error per run.
    pub average_final_error: A,
    /// Mean wall-clock time per run.
    pub average_time: std::time::Duration,
}
695
/// Summary report aggregated over all stored benchmark results.
#[derive(Debug)]
pub struct BenchmarkReport<A: Float> {
    /// Total number of individual results aggregated.
    pub total_tests: usize,
    /// Aggregates keyed by optimizer name.
    pub optimizer_performance: std::collections::HashMap<String, OptimizerPerformance<A>>,
}
704
705impl<A: Float + Send + Sync> BenchmarkReport<A> {
706 pub fn get_success_rate(&self, optimizername: &str) -> Option<f64> {
708 self.optimizer_performance.get(optimizername).map(|perf| {
709 if perf.total_runs > 0 {
710 perf.successful_runs as f64 / perf.total_runs as f64
711 } else {
712 0.0
713 }
714 })
715 }
716
717 pub fn compare_optimizers(&self, opt1: &str, opt2: &str) -> Option<OptimizerComparison<A>> {
719 let perf1 = self.optimizer_performance.get(opt1)?;
720 let perf2 = self.optimizer_performance.get(opt2)?;
721
722 Some(OptimizerComparison {
723 optimizer1: opt1.to_string(),
724 optimizer2: opt2.to_string(),
725 success_rate_diff: self.get_success_rate(opt1).unwrap_or(0.0)
726 - self.get_success_rate(opt2).unwrap_or(0.0),
727 avg_iterations_diff: perf1.average_iterations - perf2.average_iterations,
728 avg_error_diff: perf1.average_final_error - perf2.average_final_error,
729 })
730 }
731}
732
/// Pairwise metric differences (`optimizer1 - optimizer2`).
#[derive(Debug, Clone)]
pub struct OptimizerComparison<A: Float> {
    /// First optimizer name.
    pub optimizer1: String,
    /// Second optimizer name.
    pub optimizer2: String,
    /// Success-rate difference in [-1, 1].
    pub success_rate_diff: f64,
    /// Difference in mean iterations.
    pub avg_iterations_diff: f64,
    /// Difference in mean final error.
    pub avg_error_diff: A,
}
747
/// `Default` delegates to `new`, yielding an empty benchmark.
impl<A: Float + ScalarOperand + Debug + Send + Sync> Default for OptimizerBenchmark<A> {
    fn default() -> Self {
        Self::new()
    }
}
753
754pub mod visualization {
756 use super::*;
757 use std::fmt::Write;
758
    /// Records per-step optimizer telemetry (parameters, internal state,
    /// learning rate, loss) and renders ASCII visualizations of it.
    #[derive(Debug)]
    pub struct OptimizerStateVisualizer<A: Float, D: Dimension> {
        // One entry per recorded step; each entry holds one array per group.
        parameter_history: VecDeque<Vec<Array<A, D>>>,
        // Optimizer-internal snapshots (momentum norms, betas, ...).
        state_history: VecDeque<OptimizerStateSnapshot<A>>,
        // Learning rate recorded at each step.
        learning_rate_history: VecDeque<A>,
        // Loss value recorded at each step.
        loss_history: VecDeque<A>,
        // Capacity bound applied to every history deque.
        _maxhistory: usize,
        // Total steps recorded (not capped by the window).
        step_count: usize,
    }
775
776 impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> OptimizerStateVisualizer<A, D> {
777 pub fn new(_maxhistory: usize) -> Self {
779 Self {
780 parameter_history: VecDeque::with_capacity(_maxhistory),
781 state_history: VecDeque::with_capacity(_maxhistory),
782 learning_rate_history: VecDeque::with_capacity(_maxhistory),
783 loss_history: VecDeque::with_capacity(_maxhistory),
784 _maxhistory,
785 step_count: 0,
786 }
787 }
788
789 pub fn record_step(
791 &mut self,
792 parameters: &[Array<A, D>],
793 state_snapshot: OptimizerStateSnapshot<A>,
794 learning_rate: A,
795 loss_value: A,
796 ) {
797 self.step_count += 1;
798
799 self.parameter_history.push_back(parameters.to_vec());
801 if self.parameter_history.len() > self._maxhistory {
802 self.parameter_history.pop_front();
803 }
804
805 self.state_history.push_back(state_snapshot);
807 if self.state_history.len() > self._maxhistory {
808 self.state_history.pop_front();
809 }
810
811 self.learning_rate_history.push_back(learning_rate);
813 if self.learning_rate_history.len() > self._maxhistory {
814 self.learning_rate_history.pop_front();
815 }
816
817 self.loss_history.push_back(loss_value);
819 if self.loss_history.len() > self._maxhistory {
820 self.loss_history.pop_front();
821 }
822 }
823
824 pub fn generate_convergence_plot(&self, width: usize, height: usize) -> String {
826 if self.loss_history.is_empty() {
827 return "No data to visualize".to_string();
828 }
829
830 let mut plot = String::new();
831
832 let min_loss = self
834 .loss_history
835 .iter()
836 .fold(A::infinity(), |acc, &x| acc.min(x));
837 let max_loss = self
838 .loss_history
839 .iter()
840 .fold(A::neg_infinity(), |acc, &x| acc.max(x));
841
842 let loss_range = max_loss - min_loss;
843
844 writeln!(plot, "Loss Convergence (Steps: {})", self.step_count).expect("unwrap failed");
845 writeln!(
846 plot,
847 "Max: {:.6}, Min: {:.6}",
848 max_loss.to_f64().unwrap_or(0.0),
849 min_loss.to_f64().unwrap_or(0.0)
850 )
851 .expect("unwrap failed");
852 writeln!(plot, "{}", "=".repeat(width + 10)).expect("unwrap failed");
853
854 for row in 0..height {
856 let y_value = max_loss
857 - (A::from(row).expect("unwrap failed")
858 / A::from(height - 1).expect("unwrap failed"))
859 * loss_range;
860 write!(plot, "{:8.3} |", y_value.to_f64().unwrap_or(0.0)).expect("unwrap failed");
861
862 for col in 0..width {
863 let step_index = (col * self.loss_history.len()) / width;
864 if step_index < self.loss_history.len() {
865 let loss_val = self.loss_history[step_index];
866 let normalized_y = ((max_loss - loss_val) / loss_range
867 * A::from(height - 1).expect("unwrap failed"))
868 .to_usize()
869 .unwrap_or(0);
870
871 if normalized_y == row {
872 write!(plot, "*").expect("unwrap failed");
873 } else {
874 write!(plot, " ").expect("unwrap failed");
875 }
876 } else {
877 write!(plot, " ").expect("unwrap failed");
878 }
879 }
880 writeln!(plot, "|").expect("unwrap failed");
881 }
882
883 writeln!(plot, " {}", "-".repeat(width)).expect("unwrap failed");
884 writeln!(
885 plot,
886 " 0{:width$}Steps",
887 self.step_count,
888 width = width - 10
889 )
890 .expect("unwrap failed");
891
892 plot
893 }
894
        /// Renders an ASCII plot of the learning-rate history in the same
        /// layout as `generate_convergence_plot`. A constant schedule
        /// (zero range) is drawn at mid-height instead of dividing by zero.
        pub fn generate_learning_rate_plot(&self, width: usize, height: usize) -> String {
            if self.learning_rate_history.is_empty() {
                return "No learning rate data to visualize".to_string();
            }

            let mut plot = String::new();

            let min_lr = self
                .learning_rate_history
                .iter()
                .fold(A::infinity(), |acc, &x| acc.min(x));
            let max_lr = self
                .learning_rate_history
                .iter()
                .fold(A::neg_infinity(), |acc, &x| acc.max(x));

            let lr_range = max_lr - min_lr;

            writeln!(plot, "Learning Rate Schedule").expect("unwrap failed");
            writeln!(
                plot,
                "Max: {:.6}, Min: {:.6}",
                max_lr.to_f64().unwrap_or(0.0),
                min_lr.to_f64().unwrap_or(0.0)
            )
            .expect("unwrap failed");
            writeln!(plot, "{}", "=".repeat(width + 10)).expect("unwrap failed");

            for row in 0..height {
                // Learning-rate value represented by this text row (top = max).
                let y_value = max_lr
                    - (A::from(row).expect("unwrap failed")
                        / A::from(height - 1).expect("unwrap failed"))
                        * lr_range;
                write!(plot, "{:8.3} |", y_value.to_f64().unwrap_or(0.0)).expect("unwrap failed");

                for col in 0..width {
                    // Down-sample the history onto `width` columns.
                    let step_index = (col * self.learning_rate_history.len()) / width;
                    if step_index < self.learning_rate_history.len() {
                        let lr_val = self.learning_rate_history[step_index];
                        let normalized_y = if lr_range > A::zero() {
                            ((max_lr - lr_val) / lr_range
                                * A::from(height - 1).expect("unwrap failed"))
                            .to_usize()
                            .unwrap_or(0)
                        } else {
                            // Flat schedule: draw the line at mid-height.
                            height / 2
                        };

                        if normalized_y == row {
                            write!(plot, "*").expect("unwrap failed");
                        } else {
                            write!(plot, " ").expect("unwrap failed");
                        }
                    } else {
                        write!(plot, " ").expect("unwrap failed");
                    }
                }
                writeln!(plot, "|").expect("unwrap failed");
            }

            writeln!(plot, " {}", "-".repeat(width)).expect("unwrap failed");
            writeln!(
                plot,
                " 0{:width$}Steps",
                self.step_count,
                width = width - 10
            )
            .expect("unwrap failed");

            plot
        }
967
        /// Renders an ASCII heatmap of parameter evolution: one row per
        /// flattened parameter index (capped at `height`), one column per
        /// recorded step (capped at `width`); density characters encode the
        /// value normalized against the global min/max over the history.
        pub fn generate_parameter_heatmap(&self, width: usize, height: usize) -> String {
            if self.parameter_history.is_empty() {
                return "No parameter data to visualize".to_string();
            }

            let mut plot = String::new();
            writeln!(plot, "Parameter Evolution Heatmap").expect("unwrap failed");
            writeln!(plot, "{}", "=".repeat(width + 5)).expect("unwrap failed");

            // Flatten every recorded value to find the global range.
            let all_params: Vec<A> = self
                .parameter_history
                .iter()
                .flat_map(|step| step.iter().flat_map(|array| array.iter().copied()))
                .collect();

            if all_params.is_empty() {
                return "No parameter data available".to_string();
            }

            let min_param = all_params.iter().fold(A::infinity(), |acc, &x| acc.min(x));
            let max_param = all_params
                .iter()
                .fold(A::neg_infinity(), |acc, &x| acc.max(x));
            let param_range = max_param - min_param;

            // Clamp the drawn area to the requested dimensions.
            let num_steps = self.parameter_history.len().min(width);
            let num_params = if !self.parameter_history.is_empty() {
                self.parameter_history[0]
                    .iter()
                    .map(|arr| arr.len())
                    .sum::<usize>()
                    .min(height)
            } else {
                0
            };

            for param_idx in 0..num_params {
                write!(plot, "P{:3} |", param_idx).expect("unwrap failed");

                for step_idx in 0..num_steps {
                    let step_data = &self.parameter_history[step_idx];

                    // Walk the groups to locate the flattened parameter index.
                    let mut flat_idx = 0;
                    let mut found_value = None;

                    for array in step_data {
                        if flat_idx + array.len() > param_idx {
                            let local_idx = param_idx - flat_idx;
                            if let Some(&value) = array.iter().nth(local_idx) {
                                found_value = Some(value);
                                break;
                            }
                        }
                        flat_idx += array.len();
                    }

                    if let Some(value) = found_value {
                        let normalized = if param_range > A::zero() {
                            ((value - min_param) / param_range).to_f64().unwrap_or(0.0)
                        } else {
                            // All values equal: render mid-intensity.
                            0.5
                        };

                        // Bucket [0, 1] into five density characters.
                        let char = match (normalized * 4.0) as i32 {
                            0 => ' ',
                            1 => '.',
                            2 => ':',
                            3 => '*',
                            _ => '#',
                        };
                        write!(plot, "{}", char).expect("unwrap failed");
                    } else {
                        write!(plot, " ").expect("unwrap failed");
                    }
                }
                writeln!(plot, "|").expect("unwrap failed");
            }

            writeln!(plot, " {}", "-".repeat(num_steps)).expect("unwrap failed");
            writeln!(plot, " Legend: ' ' = Low, '.' < ':' < '*' < '#' = High")
                .expect("unwrap failed");
            writeln!(
                plot,
                " Range: {:.6} to {:.6}",
                min_param.to_f64().unwrap_or(0.0),
                max_param.to_f64().unwrap_or(0.0)
            )
            .expect("unwrap failed");

            plot
        }
1063
        /// Builds a multi-section text summary of the recorded histories:
        /// loss, learning rate, parameter norms, and the latest
        /// optimizer-state snapshot.
        pub fn generate_state_summary(&self) -> String {
            let mut summary = String::new();

            writeln!(summary, "Optimizer State Summary").expect("unwrap failed");
            writeln!(summary, "======================").expect("unwrap failed");
            writeln!(summary, "Total Steps: {}", self.step_count).expect("unwrap failed");
            writeln!(summary, "History Length: {}", self.parameter_history.len())
                .expect("unwrap failed");

            if let Some(current_loss) = self.loss_history.back() {
                writeln!(
                    summary,
                    "Current Loss: {:.6}",
                    current_loss.to_f64().unwrap_or(0.0)
                )
                .expect("unwrap failed");
            }

            if let Some(current_lr) = self.learning_rate_history.back() {
                writeln!(
                    summary,
                    "Current Learning Rate: {:.6}",
                    current_lr.to_f64().unwrap_or(0.0)
                )
                .expect("unwrap failed");
            }

            // Loss min/max/avg plus total improvement since the window start.
            if !self.loss_history.is_empty() {
                let min_loss = self
                    .loss_history
                    .iter()
                    .fold(A::infinity(), |acc, &x| acc.min(x));
                let max_loss = self
                    .loss_history
                    .iter()
                    .fold(A::neg_infinity(), |acc, &x| acc.max(x));
                let avg_loss = self.loss_history.iter().fold(A::zero(), |acc, &x| acc + x)
                    / A::from(self.loss_history.len()).expect("unwrap failed");

                writeln!(summary, "\nLoss Statistics:").expect("unwrap failed");
                writeln!(summary, " Min: {:.6}", min_loss.to_f64().unwrap_or(0.0))
                    .expect("unwrap failed");
                writeln!(summary, " Max: {:.6}", max_loss.to_f64().unwrap_or(0.0))
                    .expect("unwrap failed");
                writeln!(summary, " Avg: {:.6}", avg_loss.to_f64().unwrap_or(0.0))
                    .expect("unwrap failed");

                if self.loss_history.len() > 1 {
                    let first_loss = self.loss_history[0];
                    let last_loss = *self.loss_history.back().expect("unwrap failed");
                    let improvement = first_loss - last_loss;
                    // NOTE(review): a zero `first_loss` makes this rate
                    // non-finite and it is formatted via `unwrap_or(0.0)` —
                    // confirm that silent fallback is acceptable.
                    let improvement_rate = improvement / first_loss;
                    writeln!(
                        summary,
                        " Improvement: {:.6} ({:.2}%)",
                        improvement.to_f64().unwrap_or(0.0),
                        (improvement_rate.to_f64().unwrap_or(0.0) * 100.0)
                    )
                    .expect("unwrap failed");
                }
            }

            // Parameter counts and per-group L2 norms of the newest step.
            if !self.parameter_history.is_empty() {
                let current_params = self.parameter_history.back().expect("unwrap failed");
                let total_params: usize = current_params.iter().map(|arr| arr.len()).sum();
                writeln!(summary, "\nParameter Statistics:").expect("unwrap failed");
                writeln!(summary, " Total Parameters: {}", total_params).expect("unwrap failed");
                writeln!(summary, " Parameter Groups: {}", current_params.len())
                    .expect("unwrap failed");

                for (i, array) in current_params.iter().enumerate() {
                    let l2_norm = array.mapv(|x| x * x).sum().sqrt();
                    writeln!(
                        summary,
                        " Group {} L2 Norm: {:.6}",
                        i,
                        l2_norm.to_f64().unwrap_or(0.0)
                    )
                    .expect("unwrap failed");
                }
            }

            // Latest optimizer-internal snapshot, if any was recorded.
            if !self.state_history.is_empty() {
                writeln!(summary, "\nOptimizer State:").expect("unwrap failed");
                if let Some(latest_state) = self.state_history.back() {
                    writeln!(
                        summary,
                        " Momentum Norm: {:.6}",
                        latest_state.momentum_norm.to_f64().unwrap_or(0.0)
                    )
                    .expect("unwrap failed");
                    writeln!(
                        summary,
                        " Velocity Norm: {:.6}",
                        latest_state.velocity_norm.to_f64().unwrap_or(0.0)
                    )
                    .expect("unwrap failed");
                    writeln!(
                        summary,
                        " Step Size: {:.6}",
                        latest_state.effective_step_size.to_f64().unwrap_or(0.0)
                    )
                    .expect("unwrap failed");
                    writeln!(
                        summary,
                        " Beta1: {:.6}",
                        latest_state.beta1.to_f64().unwrap_or(0.0)
                    )
                    .expect("unwrap failed");
                    writeln!(
                        summary,
                        " Beta2: {:.6}",
                        latest_state.beta2.to_f64().unwrap_or(0.0)
                    )
                    .expect("unwrap failed");
                }
            }

            summary
        }
1190
1191 pub fn export_data(&self) -> VisualizationExport<A> {
1193 VisualizationExport {
1194 step_indices: (0..self.step_count).collect(),
1195 loss_history: self.loss_history.iter().copied().collect(),
1196 learning_rate_history: self.learning_rate_history.iter().copied().collect(),
1197 parameter_norms: self
1198 .parameter_history
1199 .iter()
1200 .map(|step| {
1201 step.iter()
1202 .map(|array| array.mapv(|x| x * x).sum().sqrt())
1203 .collect()
1204 })
1205 .collect(),
1206 state_snapshots: self.state_history.iter().cloned().collect(),
1207 }
1208 }
1209
1210 pub fn clear(&mut self) {
1212 self.parameter_history.clear();
1213 self.state_history.clear();
1214 self.learning_rate_history.clear();
1215 self.loss_history.clear();
1216 self.step_count = 0;
1217 }
1218
        /// Total number of steps recorded since creation (or the last `clear`).
        pub fn step_count(&self) -> usize {
            self.step_count
        }
1223 }
1224
    /// Point-in-time view of an optimizer's internal state, as supplied by
    /// the caller when recording a step.
    #[derive(Debug, Clone)]
    pub struct OptimizerStateSnapshot<A: Float> {
        /// Norm of the optimizer's momentum state.
        pub momentum_norm: A,
        /// Norm of the optimizer's velocity (second-moment) state.
        pub velocity_norm: A,
        /// Step size effectively applied on this step.
        pub effective_step_size: A,
        /// First-moment decay coefficient.
        pub beta1: A,
        /// Second-moment decay coefficient.
        pub beta2: A,
        /// Arbitrary named extra metrics.
        pub custom_fields: std::collections::HashMap<String, A>,
    }
1241
1242 impl<A: Float + Send + Sync> OptimizerStateSnapshot<A> {
1243 pub fn new() -> Self {
1245 Self {
1246 momentum_norm: A::zero(),
1247 velocity_norm: A::zero(),
1248 effective_step_size: A::zero(),
1249 beta1: A::zero(),
1250 beta2: A::zero(),
1251 custom_fields: std::collections::HashMap::new(),
1252 }
1253 }
1254
1255 pub fn with_custom_field(mut self, name: String, value: A) -> Self {
1257 self.custom_fields.insert(name, value);
1258 self
1259 }
1260 }
1261
    /// `Default` yields an all-zero snapshot via `new`.
    impl<A: Float + Send + Sync> Default for OptimizerStateSnapshot<A> {
        fn default() -> Self {
            Self::new()
        }
    }
1267
    /// Owned export of everything an `OptimizerStateVisualizer` recorded.
    #[derive(Debug, Clone)]
    pub struct VisualizationExport<A: Float> {
        /// Step numbers `0..step_count`.
        pub step_indices: Vec<usize>,
        /// Recorded loss values (windowed).
        pub loss_history: Vec<A>,
        /// Recorded learning rates (windowed).
        pub learning_rate_history: Vec<A>,
        /// Per-step, per-group parameter L2 norms.
        pub parameter_norms: Vec<Vec<A>>,
        /// Recorded optimizer-state snapshots (windowed).
        pub state_snapshots: Vec<OptimizerStateSnapshot<A>>,
    }
1282
/// Tracks several optimizers side by side, one state visualizer per
/// optimizer, for comparison reporting.
#[derive(Debug)]
pub struct OptimizerDashboard<A: Float, D: Dimension> {
    // One visualizer per optimizer, keyed by the optimizer's name.
    visualizers: std::collections::HashMap<String, OptimizerStateVisualizer<A, D>>,
    // Reserved for cross-optimizer metrics; not yet populated anywhere
    // in this file (hence the dead_code allowance).
    #[allow(dead_code)]
    comparison_metrics: Vec<ComparisonMetric<A>>,
}
1292
1293 impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> OptimizerDashboard<A, D> {
1294 pub fn new() -> Self {
1296 Self {
1297 visualizers: std::collections::HashMap::new(),
1298 comparison_metrics: Vec::new(),
1299 }
1300 }
1301
1302 pub fn add_optimizer(&mut self, name: String, maxhistory: usize) {
1304 self.visualizers
1305 .insert(name, OptimizerStateVisualizer::new(maxhistory));
1306 }
1307
1308 pub fn record_optimizer_step(
1310 &mut self,
1311 optimizername: &str,
1312 parameters: &[Array<A, D>],
1313 state_snapshot: OptimizerStateSnapshot<A>,
1314 learning_rate: A,
1315 loss_value: A,
1316 ) -> Result<()> {
1317 if let Some(visualizer) = self.visualizers.get_mut(optimizername) {
1318 visualizer.record_step(parameters, state_snapshot, learning_rate, loss_value);
1319 Ok(())
1320 } else {
1321 Err(OptimError::InvalidConfig(format!(
1322 "Optimizer '{}' not found in dashboard",
1323 optimizername
1324 )))
1325 }
1326 }
1327
1328 pub fn generate_comparison_report(&self) -> String {
1330 let mut report = String::new();
1331
1332 writeln!(report, "Optimizer Comparison Dashboard").expect("unwrap failed");
1333 writeln!(report, "===============================").expect("unwrap failed");
1334
1335 for (name, visualizer) in &self.visualizers {
1336 writeln!(report, "\n{}", name).expect("unwrap failed");
1337 writeln!(report, "{}", "-".repeat(name.len())).expect("unwrap failed");
1338
1339 if let Some(current_loss) = visualizer.loss_history.back() {
1340 writeln!(
1341 report,
1342 "Current Loss: {:.6}",
1343 current_loss.to_f64().unwrap_or(0.0)
1344 )
1345 .expect("unwrap failed");
1346 }
1347
1348 writeln!(report, "Steps: {}", visualizer.step_count).expect("unwrap failed");
1349
1350 if visualizer.loss_history.len() > 1 {
1352 let first_loss = visualizer.loss_history[0];
1353 let last_loss = *visualizer.loss_history.back().expect("unwrap failed");
1354 let improvement = first_loss - last_loss;
1355 writeln!(
1356 report,
1357 "Total Improvement: {:.6}",
1358 improvement.to_f64().unwrap_or(0.0)
1359 )
1360 .expect("unwrap failed");
1361 }
1362 }
1363
1364 if !self.visualizers.is_empty() {
1366 writeln!(report, "\nBest Performers:").expect("unwrap failed");
1367 writeln!(report, "================").expect("unwrap failed");
1368
1369 let best_current_loss = self
1370 .visualizers
1371 .iter()
1372 .filter_map(|(name, viz)| viz.loss_history.back().map(|&loss| (name, loss)))
1373 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1374
1375 if let Some((best_name, best_loss)) = best_current_loss {
1376 writeln!(
1377 report,
1378 "Lowest Current Loss: {} ({:.6})",
1379 best_name,
1380 best_loss.to_f64().unwrap_or(0.0)
1381 )
1382 .expect("unwrap failed");
1383 }
1384 }
1385
1386 report
1387 }
1388
1389 pub fn get_visualizer(
1391 &self,
1392 optimizername: &str,
1393 ) -> Option<&OptimizerStateVisualizer<A, D>> {
1394 self.visualizers.get(optimizername)
1395 }
1396
1397 pub fn get_visualizer_mut(
1399 &mut self,
1400 optimizername: &str,
1401 ) -> Option<&mut OptimizerStateVisualizer<A, D>> {
1402 self.visualizers.get_mut(optimizername)
1403 }
1404
1405 pub fn list_optimizers(&self) -> Vec<&String> {
1407 self.visualizers.keys().collect()
1408 }
1409 }
1410
// `Default` delegates to `new()`: an empty dashboard.
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default
    for OptimizerDashboard<A, D>
{
    fn default() -> Self {
        Self::new()
    }
}
1418
/// A named metric compared across optimizers.
///
/// NOTE(review): not yet populated anywhere visible in this file;
/// `values` appears to be keyed by optimizer name — confirm against the
/// code that eventually fills `comparison_metrics`.
#[derive(Debug, Clone)]
pub struct ComparisonMetric<A: Float> {
    // Human-readable metric name.
    pub name: String,
    // Metric value per optimizer (presumably keyed by optimizer name).
    pub values: std::collections::HashMap<String, A>,
}
1427}
1428
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Two steps with parallel gradients: checks step counting, per-group
    // magnitude history and direction similarity.
    #[test]
    fn test_gradient_flow_analyzer() {
        let mut analyzer = GradientFlowAnalyzer::new(100);

        let gradients1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let updates1 = vec![Array1::from_vec(vec![0.1, 0.2])];

        // Second step is the first scaled by 0.8, so direction is unchanged.
        let gradients2 = vec![Array1::from_vec(vec![0.8, 1.6])];
        let updates2 = vec![Array1::from_vec(vec![0.08, 0.16])];

        analyzer
            .record_step(&gradients1, &updates1)
            .expect("unwrap failed");
        analyzer
            .record_step(&gradients2, &updates2)
            .expect("unwrap failed");

        assert_eq!(analyzer.step_count(), 2);

        let stats = analyzer.get_stats();
        assert_eq!(stats.step_count, 2);
        assert_eq!(stats.per_group_stats.len(), 1);

        // Expected L2 norms of the two gradient vectors.
        let expected_mag1 = (1.0_f64 * 1.0 + 2.0 * 2.0).sqrt();
        let expected_mag2 = (0.8_f64 * 0.8 + 1.6 * 1.6).sqrt();

        assert_relative_eq!(
            stats.per_group_stats[0].magnitude_history[0],
            expected_mag1,
            epsilon = 1e-6
        );
        assert_relative_eq!(
            stats.per_group_stats[0].magnitude_history[1],
            expected_mag2,
            epsilon = 1e-6
        );

        // Same-direction gradients should yield near-unit mean similarity.
        assert!(stats.mean_direction_similarity > 0.9);
    }

    // Plain gradient descent should converge on the standard quadratic
    // test function within the iteration budget.
    #[test]
    #[ignore = "timeout"]
    fn test_benchmark_quadratic() {
        let mut benchmark = OptimizerBenchmark::new();
        benchmark.add_standard_test_functions();

        // Fixed-step gradient descent: x <- x - lr * grad.
        let learning_rate = 0.01;
        let mut step_function = |x: &Array1<f64>, grad: &Array1<f64>| x - &(grad * learning_rate);

        let results = benchmark
            .run_benchmark(
                "GradientDescent".to_string(),
                &mut step_function,
                1000,
                1e-6,
            )
            .expect("unwrap failed");

        assert!(!results.is_empty());

        let quadratic_result = results
            .iter()
            .find(|r| r.function_name == "Quadratic")
            .expect("unwrap failed");

        assert!(quadratic_result.converged);
        assert!(quadratic_result.final_function_value < 1e-3);
    }

    // Cosine similarity for parallel, anti-parallel and orthogonal vectors.
    #[test]
    fn test_cosine_similarity() {
        let analyzer = GradientFlowAnalyzer::<f64, scirs2_core::ndarray::Ix1>::new(10);

        // Identical vectors -> similarity 1.
        let arrays1 = vec![Array1::from_vec(vec![1.0, 0.0])];
        let arrays2 = vec![Array1::from_vec(vec![1.0, 0.0])];
        let similarity = analyzer
            .calculate_cosine_similarity(&arrays1, &arrays2)
            .expect("unwrap failed");
        assert_relative_eq!(similarity, 1.0, epsilon = 1e-6);

        // Opposite vectors -> similarity -1.
        let arrays3 = vec![Array1::from_vec(vec![-1.0, 0.0])];
        let similarity2 = analyzer
            .calculate_cosine_similarity(&arrays1, &arrays3)
            .expect("unwrap failed");
        assert_relative_eq!(similarity2, -1.0, epsilon = 1e-6);

        // Orthogonal vectors -> similarity 0.
        let arrays4 = vec![Array1::from_vec(vec![0.0, 1.0])];
        let similarity3 = analyzer
            .calculate_cosine_similarity(&arrays1, &arrays4)
            .expect("unwrap failed");
        assert_relative_eq!(similarity3, 0.0, epsilon = 1e-6);
    }

    // Benchmarks two step sizes on one function and checks the report
    // and pairwise comparison plumbing.
    #[test]
    #[ignore = "timeout"]
    fn test_benchmark_report() {
        let mut benchmark = OptimizerBenchmark::new();
        benchmark.add_test_function(TestFunction {
            name: "Simple".to_string(),
            dimension: 2,
            function: Box::new(|x: &Array1<f64>| x[0] * x[0] + x[1] * x[1]),
            gradient: Box::new(|x: &Array1<f64>| Array1::from_vec(vec![2.0 * x[0], 2.0 * x[1]])),
            optimal_value: Some(0.0),
            optimal_point: Some(Array1::zeros(2)),
        });

        // Same rule, two learning rates.
        let mut step1 = |x: &Array1<f64>, grad: &Array1<f64>| x - &(grad * 0.1);
        let mut step2 = |x: &Array1<f64>, grad: &Array1<f64>| x - &(grad * 0.05);

        benchmark
            .run_benchmark("Fast".to_string(), &mut step1, 100, 1e-3)
            .expect("unwrap failed");
        benchmark
            .run_benchmark("Slow".to_string(), &mut step2, 100, 1e-3)
            .expect("unwrap failed");

        let report = benchmark.generate_report();
        assert_eq!(report.total_tests, 2);
        assert!(report.optimizer_performance.contains_key("Fast"));
        assert!(report.optimizer_performance.contains_key("Slow"));

        let comparison = report
            .compare_optimizers("Fast", "Slow")
            .expect("unwrap failed");
        assert_eq!(comparison.optimizer1, "Fast");
        assert_eq!(comparison.optimizer2, "Slow");
    }

    // Exported visualization data: shapes match the number of recorded
    // steps and magnitudes shrink with the shrinking gradients.
    #[test]
    fn test_visualization_data_export() {
        let mut analyzer = GradientFlowAnalyzer::new(10);

        let _gradients = [Array1::from_vec(vec![1.0, 2.0])];
        let _updates = [Array1::from_vec(vec![0.1, 0.2])];

        // Five steps with gradients scaled by 1/(i+1).
        for i in 0..5 {
            let scale = 1.0 / (i + 1) as f64;
            let scaled_grad = vec![Array1::from_vec(vec![scale, 2.0 * scale])];
            let scaled_update = vec![Array1::from_vec(vec![0.1 * scale, 0.2 * scale])];
            analyzer
                .record_step(&scaled_grad, &scaled_update)
                .expect("unwrap failed");
        }

        let viz_data = analyzer.export_for_visualization();
        assert_eq!(viz_data.step_indices.len(), 5);
        assert_eq!(viz_data.magnitude_series.len(), 1);
        assert_eq!(viz_data.magnitude_series[0].len(), 5);
        assert_eq!(viz_data.direction_similarities.len(), 5);

        // Gradient magnitudes must be strictly decreasing.
        let magnitudes = &viz_data.magnitude_series[0];
        for i in 1..magnitudes.len() {
            assert!(magnitudes[i] < magnitudes[i - 1]);
        }
    }

    // Steadily shrinking gradients should register as convergence with a
    // reasonable stability score.
    #[test]
    fn test_convergence_analysis() {
        let mut analyzer = GradientFlowAnalyzer::new(10);

        for i in 0..10 {
            let scale = 1.0 / (i + 1) as f64;
            let gradients = vec![Array1::from_vec(vec![scale, scale])];
            let updates = vec![Array1::from_vec(vec![0.1 * scale, 0.1 * scale])];
            analyzer
                .record_step(&gradients, &updates)
                .expect("unwrap failed");
        }

        let stats = analyzer.get_stats();
        assert!(stats.is_converging);
        assert!(stats.stability_score > 0.5);
    }

    // Alternating-sign gradients; only checks the frequency is well-defined
    // (non-negative), not a specific value.
    #[test]
    fn test_oscillation_detection() {
        let mut analyzer = GradientFlowAnalyzer::new(10);

        let gradients = vec![Array1::from_vec(vec![1.0, 1.0])];
        let updates = vec![Array1::from_vec(vec![0.1, 0.1])];
        analyzer
            .record_step(&gradients, &updates)
            .expect("unwrap failed");

        // Flip the gradient sign every step to induce oscillation.
        for i in 1..8 {
            let sign = if i % 2 == 0 { 1.0 } else { -1.0 };
            let gradients = vec![Array1::from_vec(vec![sign, sign])];
            let updates = vec![Array1::from_vec(vec![0.1 * sign, 0.1 * sign])];
            analyzer
                .record_step(&gradients, &updates)
                .expect("unwrap failed");
        }

        let stats = analyzer.get_stats();
        assert!(stats.oscillation_frequency >= 0.0);
    }

    // Records two steps and smoke-tests the text summary and ASCII plots.
    #[test]
    fn test_optimizer_state_visualizer() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(100);

        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let state = visualization::OptimizerStateSnapshot::new()
            .with_custom_field("test_field".to_string(), 0.5);

        visualizer.record_step(&params, state, 0.01, 1.5);
        visualizer.record_step(
            &params,
            visualization::OptimizerStateSnapshot::new(),
            0.009,
            1.2,
        );

        assert_eq!(visualizer.step_count(), 2);

        // Summary should reflect the latest recorded loss.
        let summary = visualizer.generate_state_summary();
        assert!(summary.contains("Total Steps: 2"));
        assert!(summary.contains("Current Loss: 1.200000"));

        let plot = visualizer.generate_convergence_plot(40, 10);
        assert!(plot.contains("Loss Convergence"));
        assert!(plot.contains("Steps: 2"));

        let lr_plot = visualizer.generate_learning_rate_plot(40, 10);
        assert!(lr_plot.contains("Learning Rate Schedule"));

        let heatmap = visualizer.generate_parameter_heatmap(20, 5);
        assert!(heatmap.contains("Parameter Evolution Heatmap"));
    }

    // Exported data has one entry per step, with loss and learning rate
    // decreasing as recorded.
    #[test]
    fn test_visualization_export() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        for i in 0..5 {
            let params = vec![Array1::from_vec(vec![i as f64, (i * 2) as f64])];
            let state = visualization::OptimizerStateSnapshot::new();
            let lr = 0.01 / (i + 1) as f64;
            let loss = 1.0 / (i + 1) as f64;

            visualizer.record_step(&params, state, lr, loss);
        }

        let export = visualizer.export_data();
        assert_eq!(export.step_indices.len(), 5);
        assert_eq!(export.loss_history.len(), 5);
        assert_eq!(export.learning_rate_history.len(), 5);
        assert_eq!(export.parameter_norms.len(), 5);
        assert_eq!(export.state_snapshots.len(), 5);

        // Both series were recorded monotonically decreasing.
        assert!(export.loss_history[0] > export.loss_history[4]);
        assert!(export.learning_rate_history[0] > export.learning_rate_history[4]);
    }

    // Two optimizers on the dashboard; Adam has the lower loss and should
    // win the "Lowest Current Loss" line in the comparison report.
    #[test]
    fn test_optimizer_dashboard() {
        let mut dashboard = visualization::OptimizerDashboard::new();

        dashboard.add_optimizer("SGD".to_string(), 100);
        dashboard.add_optimizer("Adam".to_string(), 100);

        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let state = visualization::OptimizerStateSnapshot::new();

        dashboard
            .record_optimizer_step("SGD", &params, state.clone(), 0.01, 1.0)
            .expect("unwrap failed");
        dashboard
            .record_optimizer_step("Adam", &params, state, 0.001, 0.8)
            .expect("unwrap failed");

        let optimizers = dashboard.list_optimizers();
        assert_eq!(optimizers.len(), 2);
        assert!(optimizers.contains(&&"SGD".to_string()));
        assert!(optimizers.contains(&&"Adam".to_string()));

        let sgd_viz = dashboard.get_visualizer("SGD").expect("unwrap failed");
        assert_eq!(sgd_viz.step_count(), 1);

        let report = dashboard.generate_comparison_report();
        assert!(report.contains("Optimizer Comparison Dashboard"));
        assert!(report.contains("SGD"));
        assert!(report.contains("Adam"));
        assert!(report.contains("Lowest Current Loss: Adam"));
    }

    // Builder-style custom fields accumulate and are retrievable by name.
    #[test]
    fn test_state_snapshot_custom_fields() {
        let snapshot = visualization::OptimizerStateSnapshot::new()
            .with_custom_field("custom1".to_string(), 1.5)
            .with_custom_field("custom2".to_string(), 2.5);

        assert_eq!(snapshot.custom_fields.len(), 2);
        assert_eq!(snapshot.custom_fields.get("custom1"), Some(&1.5));
        assert_eq!(snapshot.custom_fields.get("custom2"), Some(&2.5));
    }

    // `clear` resets the step count and the summary output.
    #[test]
    fn test_visualizer_clear() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        let params = vec![Array1::from_vec(vec![1.0])];
        let state = visualization::OptimizerStateSnapshot::new();

        visualizer.record_step(&params, state, 0.01, 1.0);
        assert_eq!(visualizer.step_count(), 1);

        visualizer.clear();
        assert_eq!(visualizer.step_count(), 0);

        let summary = visualizer.generate_state_summary();
        assert!(summary.contains("Total Steps: 0"));
    }

    // Recording against an unregistered optimizer name must error.
    #[test]
    fn test_dashboard_invalid_optimizer() {
        let mut dashboard = visualization::OptimizerDashboard::new();

        let params = vec![Array1::from_vec(vec![1.0])];
        let state = visualization::OptimizerStateSnapshot::new();

        let result = dashboard.record_optimizer_step("NonExistent", &params, state, 0.01, 1.0);
        assert!(result.is_err());
    }

    // Smoke-tests the ASCII plot output: axes, data markers, and the
    // min/max line for a constant learning rate.
    #[test]
    fn test_ascii_plot_generation() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        // Linearly decreasing loss at a constant learning rate.
        for i in 0..10 {
            let params = vec![Array1::from_vec(vec![1.0])];
            let state = visualization::OptimizerStateSnapshot::new();
            let loss = 10.0 - i as f64;
            let lr = 0.1;
            visualizer.record_step(&params, state, lr, loss);
        }

        let plot = visualizer.generate_convergence_plot(20, 5);
        let lines: Vec<&str> = plot.lines().collect();
        assert!(lines.len() > 5);
        assert!(plot.contains("|"));
        assert!(plot.contains("-"));
        assert!(plot.contains("*"));

        let lr_plot = visualizer.generate_learning_rate_plot(20, 5);
        assert!(lr_plot.contains("Learning Rate Schedule"));
        assert!(lr_plot.contains("Max: 0.100000, Min: 0.100000"));
    }

    // Heatmap output includes its header, legend and value range.
    #[test]
    fn test_parameter_heatmap_generation() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        for i in 0..5 {
            let params = vec![Array1::from_vec(vec![i as f64 * 0.1, i as f64 * 0.2])];
            let state = visualization::OptimizerStateSnapshot::new();
            visualizer.record_step(&params, state, 0.01, 1.0);
        }

        let heatmap = visualizer.generate_parameter_heatmap(10, 5);
        assert!(heatmap.contains("Parameter Evolution Heatmap"));
        assert!(heatmap.contains("Legend"));
        assert!(heatmap.contains("Range"));

        assert!(heatmap.contains("P"));
    }
}