use optirs_core::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::VecDeque;
use std::fmt::Debug;

/// Boxed objective function mapping a parameter vector to a scalar value.
pub type ObjectiveFunction<A> = Box<dyn Fn(&Array1<A>) -> A>;

/// Boxed gradient function mapping a parameter vector to its gradient.
pub type GradientFunction<A> = Box<dyn Fn(&Array1<A>) -> Array1<A>>;

/// Analyzes gradient flow over the course of an optimization run.
#[derive(Debug)]
pub struct GradientFlowAnalyzer<A: Float, D: Dimension> {
    /// Per-step gradient magnitudes, one entry per parameter group.
    gradient_magnitudes: VecDeque<Vec<A>>,
    /// Per-step cosine similarity between consecutive gradient directions.
    gradient_directions: VecDeque<A>,
    /// Per-step parameter updates, one entry per parameter group.
    parameter_updates: VecDeque<Vec<Array<A, D>>>,
    /// Total number of recorded steps.
    step_count: usize,
    /// Maximum number of steps kept in the history buffers.
    _maxhistory: usize,
    /// Cached statistics, recomputed lazily when invalidated.
    stats_cache: Option<GradientFlowStats<A>>,
    /// Whether `stats_cache` reflects the current history.
    cache_valid: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientFlowAnalyzer<A, D> {
    /// Creates a new analyzer that keeps at most `_maxhistory` steps of history.
    pub fn new(_maxhistory: usize) -> Self {
        Self {
            gradient_magnitudes: VecDeque::with_capacity(_maxhistory),
            gradient_directions: VecDeque::with_capacity(_maxhistory),
            parameter_updates: VecDeque::with_capacity(_maxhistory),
            step_count: 0,
            _maxhistory,
            stats_cache: None,
            cache_valid: false,
        }
    }
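
    /// Records the gradients and parameter updates from one optimization step.
    ///
    /// A minimal usage sketch (assuming `f64` values and a single 1-D parameter
    /// group; the numbers are arbitrary and only illustrate the call shape):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{Array1, Ix1};
    ///
    /// let mut analyzer = GradientFlowAnalyzer::<f64, Ix1>::new(100);
    /// let gradients = vec![Array1::from_vec(vec![1.0, 2.0])];
    /// let updates = vec![Array1::from_vec(vec![0.1, 0.2])];
    /// analyzer.record_step(&gradients, &updates).unwrap();
    /// assert_eq!(analyzer.step_count(), 1);
    /// ```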
    pub fn record_step(
        &mut self,
        gradients: &[Array<A, D>],
        parameter_updates: &[Array<A, D>],
    ) -> Result<()> {
        if gradients.len() != parameter_updates.len() {
            return Err(OptimError::DimensionMismatch(
                "Number of gradients must match number of parameter updates".to_string(),
            ));
        }

        self.step_count += 1;

        // L2 norm of each parameter group's gradient.
        let magnitudes: Vec<A> = gradients
            .iter()
            .map(|grad| grad.mapv(|x| x * x).sum().sqrt())
            .collect();

        self.gradient_magnitudes.push_back(magnitudes);

        // Direction similarity against the previous step's parameter updates,
        // which serve as a proxy for the previous gradient direction.
        if let Some(prev_updates) = self.parameter_updates.back() {
            let similarity = self.calculate_cosine_similarity(gradients, prev_updates)?;
            self.gradient_directions.push_back(similarity);
        } else {
            // First step: nothing to compare against yet.
            self.gradient_directions.push_back(A::one());
        }

        self.parameter_updates.push_back(parameter_updates.to_vec());

        // Trim histories to the configured maximum length.
        if self.gradient_magnitudes.len() > self._maxhistory {
            self.gradient_magnitudes.pop_front();
        }
        if self.gradient_directions.len() > self._maxhistory {
            self.gradient_directions.pop_front();
        }
        if self.parameter_updates.len() > self._maxhistory {
            self.parameter_updates.pop_front();
        }

        self.cache_valid = false;

        Ok(())
    }

    /// Cosine similarity between two sets of arrays, flattened and treated as
    /// single vectors. Returns zero if either set has zero norm.
    fn calculate_cosine_similarity(
        &self,
        arrays1: &[Array<A, D>],
        arrays2: &[Array<A, D>],
    ) -> Result<A> {
        if arrays1.len() != arrays2.len() {
            return Err(OptimError::DimensionMismatch(
                "Array sets must have same length".to_string(),
            ));
        }

        let mut dot_product = A::zero();
        let mut norm1_sq = A::zero();
        let mut norm2_sq = A::zero();

        for (arr1, arr2) in arrays1.iter().zip(arrays2.iter()) {
            for (&a, &b) in arr1.iter().zip(arr2.iter()) {
                dot_product = dot_product + a * b;
                norm1_sq = norm1_sq + a * a;
                norm2_sq = norm2_sq + b * b;
            }
        }

        let norm1 = norm1_sq.sqrt();
        let norm2 = norm2_sq.sqrt();

        if norm1 > A::zero() && norm2 > A::zero() {
            Ok(dot_product / (norm1 * norm2))
        } else {
            Ok(A::zero())
        }
    }

    /// Returns aggregated gradient-flow statistics, recomputing them only when
    /// the history has changed since the last call.
    pub fn get_stats(&mut self) -> &GradientFlowStats<A> {
        if !self.cache_valid {
            self.stats_cache = Some(self.compute_stats());
            self.cache_valid = true;
        }
        self.stats_cache.as_ref().unwrap()
    }

    fn compute_stats(&self) -> GradientFlowStats<A> {
        let num_param_groups = if let Some(first) = self.gradient_magnitudes.front() {
            first.len()
        } else {
            0
        };

        let mut per_group_stats = Vec::new();
        for group_idx in 0..num_param_groups {
            let group_magnitudes: Vec<A> = self
                .gradient_magnitudes
                .iter()
                .map(|step_mags| step_mags[group_idx])
                .collect();

            let mean_magnitude = if !group_magnitudes.is_empty() {
                group_magnitudes.iter().fold(A::zero(), |acc, &x| acc + x)
                    / A::from(group_magnitudes.len()).unwrap()
            } else {
                A::zero()
            };

            // Sample variance (n - 1 denominator).
            let variance = if group_magnitudes.len() > 1 {
                let mean = mean_magnitude;
                let sum_sq_diff = group_magnitudes
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .fold(A::zero(), |acc, x| acc + x);
                sum_sq_diff / A::from(group_magnitudes.len() - 1).unwrap()
            } else {
                A::zero()
            };

            let max_magnitude = group_magnitudes
                .iter()
                .fold(A::neg_infinity(), |acc, &x| acc.max(x));

            let min_magnitude = group_magnitudes
                .iter()
                .fold(A::infinity(), |acc, &x| acc.min(x));

            per_group_stats.push(ParameterGroupStats {
                mean_magnitude,
                variance,
                std_dev: variance.sqrt(),
                max_magnitude,
                min_magnitude,
                magnitude_history: group_magnitudes,
            });
        }

        let mean_direction_similarity = if !self.gradient_directions.is_empty() {
            self.gradient_directions
                .iter()
                .fold(A::zero(), |acc, &x| acc + x)
                / A::from(self.gradient_directions.len()).unwrap()
        } else {
            A::one()
        };

        let direction_variance = if self.gradient_directions.len() > 1 {
            let mean = mean_direction_similarity;
            let sum_sq_diff = self
                .gradient_directions
                .iter()
                .map(|&x| (x - mean) * (x - mean))
                .fold(A::zero(), |acc, x| acc + x);
            sum_sq_diff / A::from(self.gradient_directions.len() - 1).unwrap()
        } else {
            A::zero()
        };

        let is_converging = self.analyze_convergence();
        let oscillation_frequency = self.calculate_oscillation_frequency();
        let stability_score = self.calculate_stability_score();

        GradientFlowStats {
            step_count: self.step_count,
            per_group_stats,
            mean_direction_similarity,
            direction_variance,
            direction_std_dev: direction_variance.sqrt(),
            is_converging,
            oscillation_frequency,
            stability_score,
            direction_history: self.gradient_directions.iter().copied().collect(),
        }
    }

    /// Heuristic convergence check: a majority of parameter groups must show
    /// mostly decreasing gradient magnitudes over the last few steps.
    fn analyze_convergence(&self) -> bool {
        if self.gradient_magnitudes.len() < 5 {
            return false;
        }

        let recent_steps = 5.min(self.gradient_magnitudes.len());
        let recent_magnitudes: Vec<_> = self
            .gradient_magnitudes
            .iter()
            .rev()
            .take(recent_steps)
            .collect();

        let mut converging_groups = 0;
        let num_groups = recent_magnitudes[0].len();

        for group_idx in 0..num_groups {
            let group_trend: Vec<A> = recent_magnitudes
                .iter()
                .rev()
                .map(|step| step[group_idx])
                .collect();

            let is_decreasing = group_trend
                .windows(2)
                .map(|window| window[1] < window[0])
                .filter(|&x| x)
                .count()
                >= group_trend.len() / 2;

            if is_decreasing {
                converging_groups += 1;
            }
        }

        converging_groups >= num_groups / 2
    }

    /// Fraction of consecutive direction-similarity samples that flip sign.
    fn calculate_oscillation_frequency(&self) -> f64 {
        if self.gradient_directions.len() < 3 {
            return 0.0;
        }

        let mut sign_changes = 0;
        let mut prev_positive = None;

        for &direction in &self.gradient_directions {
            let is_positive = direction >= A::zero();
            if let Some(prev) = prev_positive {
                if prev != is_positive {
                    sign_changes += 1;
                }
            }
            prev_positive = Some(is_positive);
        }

        sign_changes as f64 / (self.gradient_directions.len() - 1) as f64
    }

    /// Stability score in `[0, 1]` combining direction consistency with the
    /// coefficient of variation of the gradient magnitudes.
    fn calculate_stability_score(&self) -> f64 {
        if self.gradient_directions.is_empty() {
            return 1.0;
        }

        let direction_consistency = self
            .gradient_directions
            .iter()
            .fold(A::zero(), |acc, &x| acc + x.abs())
            / A::from(self.gradient_directions.len()).unwrap();

        let magnitude_consistency = if !self.gradient_magnitudes.is_empty() {
            let all_magnitudes: Vec<A> = self
                .gradient_magnitudes
                .iter()
                .flat_map(|step| step.iter())
                .copied()
                .collect();

            if all_magnitudes.len() > 1 {
                let mean = all_magnitudes.iter().fold(A::zero(), |acc, &x| acc + x)
                    / A::from(all_magnitudes.len()).unwrap();
                let variance = all_magnitudes
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .fold(A::zero(), |acc, x| acc + x)
                    / A::from(all_magnitudes.len()).unwrap();
                let cv = if mean > A::zero() {
                    variance.sqrt() / mean
                } else {
                    A::zero()
                };
                (A::one() / (A::one() + cv)).to_f64().unwrap_or(0.0)
            } else {
                1.0
            }
        } else {
            1.0
        };

        let direction_score = direction_consistency.to_f64().unwrap_or(0.0).abs();
        (direction_score + magnitude_consistency) / 2.0
    }

    /// Total number of recorded steps.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Clears all recorded history and invalidates the statistics cache.
    pub fn clear(&mut self) {
        self.gradient_magnitudes.clear();
        self.gradient_directions.clear();
        self.parameter_updates.clear();
        self.step_count = 0;
        self.cache_valid = false;
        self.stats_cache = None;
    }
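
    /// Exports the recorded history in a plain data form suitable for plotting.
    ///
    /// A small sketch of the intended use (assumes a populated analyzer; the
    /// plotting side is left to the caller):
    ///
    /// ```ignore
    /// let viz = analyzer.export_for_visualization();
    /// for (group_idx, series) in viz.magnitude_series.iter().enumerate() {
    ///     println!("group {group_idx}: {} magnitude samples", series.len());
    /// }
    /// ```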
    pub fn export_for_visualization(&self) -> VisualizationData<A> {
        let magnitude_series: Vec<Vec<A>> = if !self.gradient_magnitudes.is_empty() {
            let num_groups = self.gradient_magnitudes[0].len();
            (0..num_groups)
                .map(|group_idx| {
                    self.gradient_magnitudes
                        .iter()
                        .map(|step| step[group_idx])
                        .collect()
                })
                .collect()
        } else {
            Vec::new()
        };

        VisualizationData {
            step_indices: (0..self.step_count).collect(),
            magnitude_series,
            direction_similarities: self.gradient_directions.iter().copied().collect(),
        }
    }
}

/// Aggregated statistics describing gradient flow over the recorded history.
#[derive(Debug, Clone)]
pub struct GradientFlowStats<A: Float> {
    /// Total number of recorded steps.
    pub step_count: usize,
    /// Per-parameter-group magnitude statistics.
    pub per_group_stats: Vec<ParameterGroupStats<A>>,
    /// Mean cosine similarity between consecutive gradient directions.
    pub mean_direction_similarity: A,
    /// Variance of the direction similarities.
    pub direction_variance: A,
    /// Standard deviation of the direction similarities.
    pub direction_std_dev: A,
    /// Whether the recent gradient magnitudes suggest convergence.
    pub is_converging: bool,
    /// Fraction of consecutive direction samples that flip sign.
    pub oscillation_frequency: f64,
    /// Combined stability score in `[0, 1]`.
    pub stability_score: f64,
    /// Full history of direction similarities.
    pub direction_history: Vec<A>,
}

/// Gradient magnitude statistics for a single parameter group.
#[derive(Debug, Clone)]
pub struct ParameterGroupStats<A: Float> {
    /// Mean gradient magnitude.
    pub mean_magnitude: A,
    /// Sample variance of the gradient magnitudes.
    pub variance: A,
    /// Standard deviation of the gradient magnitudes.
    pub std_dev: A,
    /// Largest recorded gradient magnitude.
    pub max_magnitude: A,
    /// Smallest recorded gradient magnitude.
    pub min_magnitude: A,
    /// Full history of gradient magnitudes.
    pub magnitude_history: Vec<A>,
}

/// Plain data export of the analyzer history for external plotting.
#[derive(Debug, Clone)]
pub struct VisualizationData<A: Float> {
    /// Step indices, `0..step_count`.
    pub step_indices: Vec<usize>,
    /// One magnitude series per parameter group.
    pub magnitude_series: Vec<Vec<A>>,
    /// Direction similarity per step.
    pub direction_similarities: Vec<A>,
}

/// Runs optimizers against a suite of analytic test functions.
pub struct OptimizerBenchmark<A: Float> {
    /// Registered test functions.
    test_functions: Vec<TestFunction<A>>,
    /// Accumulated benchmark results.
    results: Vec<BenchmarkResult<A>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> OptimizerBenchmark<A> {
    /// Creates an empty benchmark with no test functions registered.
    pub fn new() -> Self {
        Self {
            test_functions: Vec::new(),
            results: Vec::new(),
        }
    }

    /// Registers a custom test function.
    pub fn add_test_function(&mut self, testfunction: TestFunction<A>) {
        self.test_functions.push(testfunction);
    }

    /// Registers the standard quadratic, Rosenbrock, and sphere test functions.
    pub fn add_standard_test_functions(&mut self) {
        self.add_test_function(TestFunction {
            name: "Quadratic".to_string(),
            dimension: 10,
            function: Box::new(|x: &Array1<A>| x.mapv(|val| val * val).sum()),
            gradient: Box::new(|x: &Array1<A>| x.mapv(|val| A::from(2.0).unwrap() * val)),
            optimal_value: Some(A::zero()),
            optimal_point: Some(Array1::zeros(10)),
        });

        self.add_test_function(TestFunction {
            name: "Rosenbrock".to_string(),
            dimension: 2,
            function: Box::new(|x: &Array1<A>| {
                let a = A::one();
                let b = A::from(100.0).unwrap();
                let term1 = (a - x[0]) * (a - x[0]);
                let term2 = b * (x[1] - x[0] * x[0]) * (x[1] - x[0] * x[0]);
                term1 + term2
            }),
            gradient: Box::new(|x: &Array1<A>| {
                let a = A::one();
                let b = A::from(100.0).unwrap();
                let grad_x = A::from(-2.0).unwrap() * (a - x[0])
                    - A::from(4.0).unwrap() * b * x[0] * (x[1] - x[0] * x[0]);
                let grad_y = A::from(2.0).unwrap() * b * (x[1] - x[0] * x[0]);
                Array1::from_vec(vec![grad_x, grad_y])
            }),
            optimal_value: Some(A::zero()),
            optimal_point: Some(Array1::from_vec(vec![A::one(), A::one()])),
        });

        self.add_test_function(TestFunction {
            name: "Sphere".to_string(),
            dimension: 5,
            function: Box::new(|x: &Array1<A>| x.mapv(|val| val * val).sum()),
            gradient: Box::new(|x: &Array1<A>| x.mapv(|val| A::from(2.0).unwrap() * val)),
            optimal_value: Some(A::zero()),
            optimal_point: Some(Array1::zeros(5)),
        });
    }
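
    /// Runs every registered test function with the supplied update rule and
    /// returns one [`BenchmarkResult`] per function.
    ///
    /// A minimal sketch using a plain gradient-descent step (assumes `A = f64`;
    /// the learning rate and tolerance are illustrative only):
    ///
    /// ```ignore
    /// let mut benchmark = OptimizerBenchmark::<f64>::new();
    /// benchmark.add_standard_test_functions();
    ///
    /// let lr = 0.01;
    /// let results = benchmark
    ///     .run_benchmark("GradientDescent".to_string(), |x, grad| x - &(grad * lr), 1000, 1e-6)
    ///     .unwrap();
    /// assert_eq!(results.len(), 3);
    /// ```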
    pub fn run_benchmark<F>(
        &mut self,
        optimizername: String,
        mut optimization_step: F,
        max_iterations: usize,
        tolerance: A,
    ) -> Result<Vec<BenchmarkResult<A>>>
    where
        F: FnMut(&Array1<A>, &Array1<A>) -> Array1<A>,
    {
        let mut results = Vec::new();

        for testfunction in &self.test_functions {
            // Start every run from the same point (all coordinates at 0.5).
            let mut x = Array1::from_vec(
                (0..testfunction.dimension)
                    .map(|_| A::from(0.5).unwrap())
                    .collect(),
            );

            let mut function_values = Vec::new();
            let mut gradient_norms = Vec::new();
            let mut convergence_step = None;

            let start_time = std::time::Instant::now();

            for iteration in 0..max_iterations {
                let f_val = (testfunction.function)(&x);
                let grad = (testfunction.gradient)(&x);
                let grad_norm = grad.mapv(|g| g * g).sum().sqrt();

                function_values.push(f_val);
                gradient_norms.push(grad_norm);

                if grad_norm < tolerance {
                    convergence_step = Some(iteration);
                    break;
                }

                x = optimization_step(&x, &grad);
            }

            let elapsed = start_time.elapsed();

            let final_error = if let Some(optimal_value) = testfunction.optimal_value {
                (function_values.last().copied().unwrap() - optimal_value).abs()
            } else {
                A::zero()
            };

            let result = BenchmarkResult {
                optimizername: optimizername.clone(),
                function_name: testfunction.name.clone(),
                converged: convergence_step.is_some(),
                convergence_step,
                final_function_value: *function_values.last().unwrap(),
                final_gradient_norm: *gradient_norms.last().unwrap(),
                final_error,
                iterations_taken: function_values.len(),
                elapsed_time: elapsed,
                function_evaluations: function_values.len(),
                function_value_history: function_values,
                gradient_norm_history: gradient_norms,
            };

            results.push(result.clone());
        }

        self.results.extend(results.clone());
        Ok(results)
    }

    /// Returns all accumulated benchmark results.
    pub fn get_results(&self) -> &[BenchmarkResult<A>] {
        &self.results
    }

    /// Clears all accumulated benchmark results.
    pub fn clear_results(&mut self) {
        self.results.clear();
    }

    /// Aggregates the accumulated results into per-optimizer performance summaries.
    pub fn generate_report(&self) -> BenchmarkReport<A> {
        let mut optimizer_performance = std::collections::HashMap::new();

        for result in &self.results {
            let entry = optimizer_performance
                .entry(result.optimizername.clone())
                .or_insert_with(|| OptimizerPerformance {
                    total_runs: 0,
                    successful_runs: 0,
                    average_iterations: 0.0,
                    average_final_error: A::zero(),
                    average_time: std::time::Duration::from_secs(0),
                });

            entry.total_runs += 1;
            if result.converged {
                entry.successful_runs += 1;
            }
            entry.average_iterations += result.iterations_taken as f64;
            entry.average_final_error = entry.average_final_error + result.final_error;
            entry.average_time += result.elapsed_time;
        }

        // Convert accumulated sums into averages.
        for performance in optimizer_performance.values_mut() {
            if performance.total_runs > 0 {
                performance.average_iterations /= performance.total_runs as f64;
                performance.average_final_error =
                    performance.average_final_error / A::from(performance.total_runs).unwrap();
                performance.average_time /= performance.total_runs as u32;
            }
        }

        BenchmarkReport {
            total_tests: self.results.len(),
            optimizer_performance,
        }
    }
}

/// An analytic test function with an optional known optimum.
pub struct TestFunction<A: Float> {
    /// Human-readable name of the test function.
    pub name: String,
    /// Dimensionality of the parameter vector.
    pub dimension: usize,
    /// Objective function.
    pub function: ObjectiveFunction<A>,
    /// Gradient of the objective function.
    pub gradient: GradientFunction<A>,
    /// Known optimal value, if any.
    pub optimal_value: Option<A>,
    /// Known optimal point, if any.
    pub optimal_point: Option<Array1<A>>,
}

/// The outcome of running one optimizer on one test function.
#[derive(Debug, Clone)]
pub struct BenchmarkResult<A: Float> {
    /// Name of the optimizer that produced this result.
    pub optimizername: String,
    /// Name of the test function.
    pub function_name: String,
    /// Whether the gradient norm dropped below the tolerance.
    pub converged: bool,
    /// Iteration at which convergence was detected, if any.
    pub convergence_step: Option<usize>,
    /// Final objective value.
    pub final_function_value: A,
    /// Final gradient norm.
    pub final_gradient_norm: A,
    /// Absolute error against the known optimum (zero if unknown).
    pub final_error: A,
    /// Number of iterations performed.
    pub iterations_taken: usize,
    /// Wall-clock time of the run.
    pub elapsed_time: std::time::Duration,
    /// Number of objective evaluations.
    pub function_evaluations: usize,
    /// Objective value per iteration.
    pub function_value_history: Vec<A>,
    /// Gradient norm per iteration.
    pub gradient_norm_history: Vec<A>,
}

/// Aggregate performance of one optimizer across all of its runs.
#[derive(Debug, Clone)]
pub struct OptimizerPerformance<A: Float> {
    /// Total number of runs.
    pub total_runs: usize,
    /// Number of runs that converged.
    pub successful_runs: usize,
    /// Average iterations per run.
    pub average_iterations: f64,
    /// Average final error per run.
    pub average_final_error: A,
    /// Average wall-clock time per run.
    pub average_time: std::time::Duration,
}

/// Summary report over all benchmark runs, keyed by optimizer name.
#[derive(Debug)]
pub struct BenchmarkReport<A: Float> {
    /// Total number of benchmark runs recorded.
    pub total_tests: usize,
    /// Per-optimizer performance summaries.
    pub optimizer_performance: std::collections::HashMap<String, OptimizerPerformance<A>>,
}

impl<A: Float + Send + Sync> BenchmarkReport<A> {
    /// Fraction of runs that converged for the named optimizer, if it exists.
    pub fn get_success_rate(&self, optimizername: &str) -> Option<f64> {
        self.optimizer_performance.get(optimizername).map(|perf| {
            if perf.total_runs > 0 {
                perf.successful_runs as f64 / perf.total_runs as f64
            } else {
                0.0
            }
        })
    }

    /// Compares two optimizers by success rate, iteration count, and final error.
    /// Returns `None` if either optimizer is missing from the report.
    pub fn compare_optimizers(&self, opt1: &str, opt2: &str) -> Option<OptimizerComparison<A>> {
        let perf1 = self.optimizer_performance.get(opt1)?;
        let perf2 = self.optimizer_performance.get(opt2)?;

        Some(OptimizerComparison {
            optimizer1: opt1.to_string(),
            optimizer2: opt2.to_string(),
            success_rate_diff: self.get_success_rate(opt1).unwrap_or(0.0)
                - self.get_success_rate(opt2).unwrap_or(0.0),
            avg_iterations_diff: perf1.average_iterations - perf2.average_iterations,
            avg_error_diff: perf1.average_final_error - perf2.average_final_error,
        })
    }
}

/// Pairwise comparison between two optimizers (first minus second).
#[derive(Debug, Clone)]
pub struct OptimizerComparison<A: Float> {
    /// Name of the first optimizer.
    pub optimizer1: String,
    /// Name of the second optimizer.
    pub optimizer2: String,
    /// Difference in success rates.
    pub success_rate_diff: f64,
    /// Difference in average iterations.
    pub avg_iterations_diff: f64,
    /// Difference in average final error.
    pub avg_error_diff: A,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Default for OptimizerBenchmark<A> {
    fn default() -> Self {
        Self::new()
    }
}

/// ASCII visualization utilities for optimizer state and training progress.
pub mod visualization {
    use super::*;
    use std::fmt::Write;

    /// Records optimizer state over time and renders simple ASCII visualizations.
    #[derive(Debug)]
    pub struct OptimizerStateVisualizer<A: Float, D: Dimension> {
        /// Per-step parameter snapshots, one entry per parameter group.
        parameter_history: VecDeque<Vec<Array<A, D>>>,
        /// Per-step optimizer state snapshots.
        state_history: VecDeque<OptimizerStateSnapshot<A>>,
        /// Per-step learning rates.
        learning_rate_history: VecDeque<A>,
        /// Per-step loss values.
        loss_history: VecDeque<A>,
        /// Maximum number of steps kept in the history buffers.
        _maxhistory: usize,
        /// Total number of recorded steps.
        step_count: usize,
    }

    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> OptimizerStateVisualizer<A, D> {
        /// Creates a new visualizer that keeps at most `_maxhistory` steps of history.
        pub fn new(_maxhistory: usize) -> Self {
            Self {
                parameter_history: VecDeque::with_capacity(_maxhistory),
                state_history: VecDeque::with_capacity(_maxhistory),
                learning_rate_history: VecDeque::with_capacity(_maxhistory),
                loss_history: VecDeque::with_capacity(_maxhistory),
                _maxhistory,
                step_count: 0,
            }
        }
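
        /// Records the parameters, optimizer state, learning rate, and loss for one step.
        ///
        /// A minimal sketch (assumes `f64` values, a single 1-D parameter group,
        /// and a default state snapshot; the numbers are illustrative only):
        ///
        /// ```ignore
        /// use scirs2_core::ndarray::{Array1, Ix1};
        ///
        /// let mut visualizer = OptimizerStateVisualizer::<f64, Ix1>::new(100);
        /// let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        /// visualizer.record_step(&params, OptimizerStateSnapshot::new(), 0.01, 1.5);
        /// assert_eq!(visualizer.step_count(), 1);
        /// ```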
        pub fn record_step(
            &mut self,
            parameters: &[Array<A, D>],
            state_snapshot: OptimizerStateSnapshot<A>,
            learning_rate: A,
            loss_value: A,
        ) {
            self.step_count += 1;

            self.parameter_history.push_back(parameters.to_vec());
            if self.parameter_history.len() > self._maxhistory {
                self.parameter_history.pop_front();
            }

            self.state_history.push_back(state_snapshot);
            if self.state_history.len() > self._maxhistory {
                self.state_history.pop_front();
            }

            self.learning_rate_history.push_back(learning_rate);
            if self.learning_rate_history.len() > self._maxhistory {
                self.learning_rate_history.pop_front();
            }

            self.loss_history.push_back(loss_value);
            if self.loss_history.len() > self._maxhistory {
                self.loss_history.pop_front();
            }
        }
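
        /// Renders the recorded loss history as an ASCII plot with the given
        /// width and height in characters.
        ///
        /// A small sketch of the intended use (assumes a visualizer that has
        /// already recorded some steps; the exact characters depend on the data):
        ///
        /// ```ignore
        /// let plot = visualizer.generate_convergence_plot(60, 15);
        /// println!("{plot}");
        /// // First lines look like:
        /// // Loss Convergence (Steps: ...)
        /// // Max: ..., Min: ...
        /// ```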
        pub fn generate_convergence_plot(&self, width: usize, height: usize) -> String {
            if self.loss_history.is_empty() {
                return "No data to visualize".to_string();
            }

            let mut plot = String::new();

            let min_loss = self
                .loss_history
                .iter()
                .fold(A::infinity(), |acc, &x| acc.min(x));
            let max_loss = self
                .loss_history
                .iter()
                .fold(A::neg_infinity(), |acc, &x| acc.max(x));

            let loss_range = max_loss - min_loss;

            writeln!(plot, "Loss Convergence (Steps: {})", self.step_count).unwrap();
            writeln!(
                plot,
                "Max: {:.6}, Min: {:.6}",
                max_loss.to_f64().unwrap_or(0.0),
                min_loss.to_f64().unwrap_or(0.0)
            )
            .unwrap();
            writeln!(plot, "{}", "=".repeat(width + 10)).unwrap();

            for row in 0..height {
                let y_value =
                    max_loss - (A::from(row).unwrap() / A::from(height - 1).unwrap()) * loss_range;
                write!(plot, "{:8.3} |", y_value.to_f64().unwrap_or(0.0)).unwrap();

                for col in 0..width {
                    let step_index = (col * self.loss_history.len()) / width;
                    if step_index < self.loss_history.len() {
                        let loss_val = self.loss_history[step_index];
                        // Guard against a flat loss history to avoid dividing by zero.
                        let normalized_y = if loss_range > A::zero() {
                            ((max_loss - loss_val) / loss_range
                                * A::from(height - 1).unwrap())
                            .to_usize()
                            .unwrap_or(0)
                        } else {
                            height / 2
                        };

                        if normalized_y == row {
                            write!(plot, "*").unwrap();
                        } else {
                            write!(plot, " ").unwrap();
                        }
                    } else {
                        write!(plot, " ").unwrap();
                    }
                }
                writeln!(plot, "|").unwrap();
            }

            writeln!(plot, " {}", "-".repeat(width)).unwrap();
            writeln!(
                plot,
                " 0{:width$}Steps",
                self.step_count,
                width = width - 10
            )
            .unwrap();

            plot
        }

        /// Renders the recorded learning-rate history as an ASCII plot.
        pub fn generate_learning_rate_plot(&self, width: usize, height: usize) -> String {
            if self.learning_rate_history.is_empty() {
                return "No learning rate data to visualize".to_string();
            }

            let mut plot = String::new();

            let min_lr = self
                .learning_rate_history
                .iter()
                .fold(A::infinity(), |acc, &x| acc.min(x));
            let max_lr = self
                .learning_rate_history
                .iter()
                .fold(A::neg_infinity(), |acc, &x| acc.max(x));

            let lr_range = max_lr - min_lr;

            writeln!(plot, "Learning Rate Schedule").unwrap();
            writeln!(
                plot,
                "Max: {:.6}, Min: {:.6}",
                max_lr.to_f64().unwrap_or(0.0),
                min_lr.to_f64().unwrap_or(0.0)
            )
            .unwrap();
            writeln!(plot, "{}", "=".repeat(width + 10)).unwrap();

            for row in 0..height {
                let y_value =
                    max_lr - (A::from(row).unwrap() / A::from(height - 1).unwrap()) * lr_range;
                write!(plot, "{:8.3} |", y_value.to_f64().unwrap_or(0.0)).unwrap();

                for col in 0..width {
                    let step_index = (col * self.learning_rate_history.len()) / width;
                    if step_index < self.learning_rate_history.len() {
                        let lr_val = self.learning_rate_history[step_index];
                        let normalized_y = if lr_range > A::zero() {
                            ((max_lr - lr_val) / lr_range * A::from(height - 1).unwrap())
                                .to_usize()
                                .unwrap_or(0)
                        } else {
                            height / 2
                        };

                        if normalized_y == row {
                            write!(plot, "*").unwrap();
                        } else {
                            write!(plot, " ").unwrap();
                        }
                    } else {
                        write!(plot, " ").unwrap();
                    }
                }
                writeln!(plot, "|").unwrap();
            }

            writeln!(plot, " {}", "-".repeat(width)).unwrap();
            writeln!(
                plot,
                " 0{:width$}Steps",
                self.step_count,
                width = width - 10
            )
            .unwrap();

            plot
        }

        /// Renders a character heatmap of parameter values over time (rows are
        /// flattened parameter indices, columns are steps).
        pub fn generate_parameter_heatmap(&self, width: usize, height: usize) -> String {
            if self.parameter_history.is_empty() {
                return "No parameter data to visualize".to_string();
            }

            let mut plot = String::new();
            writeln!(plot, "Parameter Evolution Heatmap").unwrap();
            writeln!(plot, "{}", "=".repeat(width + 5)).unwrap();

            let all_params: Vec<A> = self
                .parameter_history
                .iter()
                .flat_map(|step| step.iter().flat_map(|array| array.iter().copied()))
                .collect();

            if all_params.is_empty() {
                return "No parameter data available".to_string();
            }

            let min_param = all_params.iter().fold(A::infinity(), |acc, &x| acc.min(x));
            let max_param = all_params
                .iter()
                .fold(A::neg_infinity(), |acc, &x| acc.max(x));
            let param_range = max_param - min_param;

            let num_steps = self.parameter_history.len().min(width);
            let num_params = if !self.parameter_history.is_empty() {
                self.parameter_history[0]
                    .iter()
                    .map(|arr| arr.len())
                    .sum::<usize>()
                    .min(height)
            } else {
                0
            };

            for param_idx in 0..num_params {
                write!(plot, "P{:3} |", param_idx).unwrap();

                for step_idx in 0..num_steps {
                    let step_data = &self.parameter_history[step_idx];

                    // Locate the value at the flattened index `param_idx`.
                    let mut flat_idx = 0;
                    let mut found_value = None;

                    for array in step_data {
                        if flat_idx + array.len() > param_idx {
                            let local_idx = param_idx - flat_idx;
                            if let Some(&value) = array.iter().nth(local_idx) {
                                found_value = Some(value);
                                break;
                            }
                        }
                        flat_idx += array.len();
                    }

                    if let Some(value) = found_value {
                        let normalized = if param_range > A::zero() {
                            ((value - min_param) / param_range).to_f64().unwrap_or(0.0)
                        } else {
                            0.5
                        };

                        let char = match (normalized * 4.0) as i32 {
                            0 => ' ',
                            1 => '.',
                            2 => ':',
                            3 => '*',
                            _ => '#',
                        };
                        write!(plot, "{}", char).unwrap();
                    } else {
                        write!(plot, " ").unwrap();
                    }
                }
                writeln!(plot, "|").unwrap();
            }

            writeln!(plot, " {}", "-".repeat(num_steps)).unwrap();
            writeln!(plot, " Legend: ' ' = Low, '.' < ':' < '*' < '#' = High").unwrap();
            writeln!(
                plot,
                " Range: {:.6} to {:.6}",
                min_param.to_f64().unwrap_or(0.0),
                max_param.to_f64().unwrap_or(0.0)
            )
            .unwrap();

            plot
        }

        /// Summarizes the recorded history and the latest optimizer state as text.
        pub fn generate_state_summary(&self) -> String {
            let mut summary = String::new();

            writeln!(summary, "Optimizer State Summary").unwrap();
            writeln!(summary, "======================").unwrap();
            writeln!(summary, "Total Steps: {}", self.step_count).unwrap();
            writeln!(summary, "History Length: {}", self.parameter_history.len()).unwrap();

            if let Some(current_loss) = self.loss_history.back() {
                writeln!(
                    summary,
                    "Current Loss: {:.6}",
                    current_loss.to_f64().unwrap_or(0.0)
                )
                .unwrap();
            }

            if let Some(current_lr) = self.learning_rate_history.back() {
                writeln!(
                    summary,
                    "Current Learning Rate: {:.6}",
                    current_lr.to_f64().unwrap_or(0.0)
                )
                .unwrap();
            }

            if !self.loss_history.is_empty() {
                let min_loss = self
                    .loss_history
                    .iter()
                    .fold(A::infinity(), |acc, &x| acc.min(x));
                let max_loss = self
                    .loss_history
                    .iter()
                    .fold(A::neg_infinity(), |acc, &x| acc.max(x));
                let avg_loss = self.loss_history.iter().fold(A::zero(), |acc, &x| acc + x)
                    / A::from(self.loss_history.len()).unwrap();

                writeln!(summary, "\nLoss Statistics:").unwrap();
                writeln!(summary, " Min: {:.6}", min_loss.to_f64().unwrap_or(0.0)).unwrap();
                writeln!(summary, " Max: {:.6}", max_loss.to_f64().unwrap_or(0.0)).unwrap();
                writeln!(summary, " Avg: {:.6}", avg_loss.to_f64().unwrap_or(0.0)).unwrap();

                if self.loss_history.len() > 1 {
                    let first_loss = self.loss_history[0];
                    let last_loss = *self.loss_history.back().unwrap();
                    let improvement = first_loss - last_loss;
                    let improvement_rate = improvement / first_loss;
                    writeln!(
                        summary,
                        " Improvement: {:.6} ({:.2}%)",
                        improvement.to_f64().unwrap_or(0.0),
                        (improvement_rate.to_f64().unwrap_or(0.0) * 100.0)
                    )
                    .unwrap();
                }
            }

            if !self.parameter_history.is_empty() {
                let current_params = self.parameter_history.back().unwrap();
                let total_params: usize = current_params.iter().map(|arr| arr.len()).sum();
                writeln!(summary, "\nParameter Statistics:").unwrap();
                writeln!(summary, " Total Parameters: {}", total_params).unwrap();
                writeln!(summary, " Parameter Groups: {}", current_params.len()).unwrap();

                for (i, array) in current_params.iter().enumerate() {
                    let l2_norm = array.mapv(|x| x * x).sum().sqrt();
                    writeln!(
                        summary,
                        " Group {} L2 Norm: {:.6}",
                        i,
                        l2_norm.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                }
            }

            if !self.state_history.is_empty() {
                writeln!(summary, "\nOptimizer State:").unwrap();
                if let Some(latest_state) = self.state_history.back() {
                    writeln!(
                        summary,
                        " Momentum Norm: {:.6}",
                        latest_state.momentum_norm.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                    writeln!(
                        summary,
                        " Velocity Norm: {:.6}",
                        latest_state.velocity_norm.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                    writeln!(
                        summary,
                        " Step Size: {:.6}",
                        latest_state.effective_step_size.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                    writeln!(
                        summary,
                        " Beta1: {:.6}",
                        latest_state.beta1.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                    writeln!(
                        summary,
                        " Beta2: {:.6}",
                        latest_state.beta2.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                }
            }

            summary
        }

        /// Exports the recorded history in a plain data form for external tooling.
        pub fn export_data(&self) -> VisualizationExport<A> {
            VisualizationExport {
                step_indices: (0..self.step_count).collect(),
                loss_history: self.loss_history.iter().copied().collect(),
                learning_rate_history: self.learning_rate_history.iter().copied().collect(),
                parameter_norms: self
                    .parameter_history
                    .iter()
                    .map(|step| {
                        step.iter()
                            .map(|array| array.mapv(|x| x * x).sum().sqrt())
                            .collect()
                    })
                    .collect(),
                state_snapshots: self.state_history.iter().cloned().collect(),
            }
        }

        /// Clears all recorded history.
        pub fn clear(&mut self) {
            self.parameter_history.clear();
            self.state_history.clear();
            self.learning_rate_history.clear();
            self.loss_history.clear();
            self.step_count = 0;
        }

        /// Total number of recorded steps.
        pub fn step_count(&self) -> usize {
            self.step_count
        }
    }

    /// A snapshot of scalar optimizer state at a single step.
    #[derive(Debug, Clone)]
    pub struct OptimizerStateSnapshot<A: Float> {
        /// Norm of the momentum buffer.
        pub momentum_norm: A,
        /// Norm of the velocity (second-moment) buffer.
        pub velocity_norm: A,
        /// Effective step size taken.
        pub effective_step_size: A,
        /// First-moment decay coefficient.
        pub beta1: A,
        /// Second-moment decay coefficient.
        pub beta2: A,
        /// Additional named scalar fields.
        pub custom_fields: std::collections::HashMap<String, A>,
    }

    impl<A: Float + Send + Sync> OptimizerStateSnapshot<A> {
        /// Creates a snapshot with all fields set to zero.
        pub fn new() -> Self {
            Self {
                momentum_norm: A::zero(),
                velocity_norm: A::zero(),
                effective_step_size: A::zero(),
                beta1: A::zero(),
                beta2: A::zero(),
                custom_fields: std::collections::HashMap::new(),
            }
        }

        /// Adds a named custom field to the snapshot.
        pub fn with_custom_field(mut self, name: String, value: A) -> Self {
            self.custom_fields.insert(name, value);
            self
        }
    }

    impl<A: Float + Send + Sync> Default for OptimizerStateSnapshot<A> {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Plain data export of a visualizer's history.
    #[derive(Debug, Clone)]
    pub struct VisualizationExport<A: Float> {
        /// Step indices, `0..step_count`.
        pub step_indices: Vec<usize>,
        /// Loss value per step.
        pub loss_history: Vec<A>,
        /// Learning rate per step.
        pub learning_rate_history: Vec<A>,
        /// L2 norm of each parameter group per step.
        pub parameter_norms: Vec<Vec<A>>,
        /// Optimizer state snapshot per step.
        pub state_snapshots: Vec<OptimizerStateSnapshot<A>>,
    }

    /// Tracks several optimizers side by side for comparison.
    #[derive(Debug)]
    pub struct OptimizerDashboard<A: Float, D: Dimension> {
        /// One visualizer per registered optimizer, keyed by name.
        visualizers: std::collections::HashMap<String, OptimizerStateVisualizer<A, D>>,
        /// Reserved for future cross-optimizer metrics.
        #[allow(dead_code)]
        comparison_metrics: Vec<ComparisonMetric<A>>,
    }

    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> OptimizerDashboard<A, D> {
        /// Creates an empty dashboard.
        pub fn new() -> Self {
            Self {
                visualizers: std::collections::HashMap::new(),
                comparison_metrics: Vec::new(),
            }
        }

        /// Registers an optimizer under `name` with the given history capacity.
        pub fn add_optimizer(&mut self, name: String, maxhistory: usize) {
            self.visualizers
                .insert(name, OptimizerStateVisualizer::new(maxhistory));
        }
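
        /// Records one step for a previously registered optimizer.
        ///
        /// Returns an error if `optimizername` was never added via `add_optimizer`.
        /// A minimal sketch (assumes `f64` values and a single 1-D parameter group):
        ///
        /// ```ignore
        /// use scirs2_core::ndarray::{Array1, Ix1};
        ///
        /// let mut dashboard = OptimizerDashboard::<f64, Ix1>::new();
        /// dashboard.add_optimizer("SGD".to_string(), 100);
        ///
        /// let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        /// dashboard
        ///     .record_optimizer_step("SGD", &params, OptimizerStateSnapshot::new(), 0.01, 1.0)
        ///     .unwrap();
        /// ```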
        pub fn record_optimizer_step(
            &mut self,
            optimizername: &str,
            parameters: &[Array<A, D>],
            state_snapshot: OptimizerStateSnapshot<A>,
            learning_rate: A,
            loss_value: A,
        ) -> Result<()> {
            if let Some(visualizer) = self.visualizers.get_mut(optimizername) {
                visualizer.record_step(parameters, state_snapshot, learning_rate, loss_value);
                Ok(())
            } else {
                Err(OptimError::InvalidConfig(format!(
                    "Optimizer '{}' not found in dashboard",
                    optimizername
                )))
            }
        }

        /// Generates a textual comparison of all registered optimizers.
        pub fn generate_comparison_report(&self) -> String {
            let mut report = String::new();

            writeln!(report, "Optimizer Comparison Dashboard").unwrap();
            writeln!(report, "===============================").unwrap();

            for (name, visualizer) in &self.visualizers {
                writeln!(report, "\n{}", name).unwrap();
                writeln!(report, "{}", "-".repeat(name.len())).unwrap();

                if let Some(current_loss) = visualizer.loss_history.back() {
                    writeln!(
                        report,
                        "Current Loss: {:.6}",
                        current_loss.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                }

                writeln!(report, "Steps: {}", visualizer.step_count).unwrap();

                if visualizer.loss_history.len() > 1 {
                    let first_loss = visualizer.loss_history[0];
                    let last_loss = *visualizer.loss_history.back().unwrap();
                    let improvement = first_loss - last_loss;
                    writeln!(
                        report,
                        "Total Improvement: {:.6}",
                        improvement.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                }
            }

            if !self.visualizers.is_empty() {
                writeln!(report, "\nBest Performers:").unwrap();
                writeln!(report, "================").unwrap();

                let best_current_loss = self
                    .visualizers
                    .iter()
                    .filter_map(|(name, viz)| viz.loss_history.back().map(|&loss| (name, loss)))
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                if let Some((best_name, best_loss)) = best_current_loss {
                    writeln!(
                        report,
                        "Lowest Current Loss: {} ({:.6})",
                        best_name,
                        best_loss.to_f64().unwrap_or(0.0)
                    )
                    .unwrap();
                }
            }

            report
        }

        /// Returns the visualizer for the named optimizer, if registered.
        pub fn get_visualizer(
            &self,
            optimizername: &str,
        ) -> Option<&OptimizerStateVisualizer<A, D>> {
            self.visualizers.get(optimizername)
        }

        /// Returns a mutable reference to the named optimizer's visualizer, if registered.
        pub fn get_visualizer_mut(
            &mut self,
            optimizername: &str,
        ) -> Option<&mut OptimizerStateVisualizer<A, D>> {
            self.visualizers.get_mut(optimizername)
        }

        /// Lists the names of all registered optimizers.
        pub fn list_optimizers(&self) -> Vec<&String> {
            self.visualizers.keys().collect()
        }
    }

    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default
        for OptimizerDashboard<A, D>
    {
        fn default() -> Self {
            Self::new()
        }
    }

    /// A named metric with one value per optimizer.
    #[derive(Debug, Clone)]
    pub struct ComparisonMetric<A: Float> {
        /// Name of the metric.
        pub name: String,
        /// Metric value per optimizer, keyed by optimizer name.
        pub values: std::collections::HashMap<String, A>,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_gradient_flow_analyzer() {
        let mut analyzer = GradientFlowAnalyzer::new(100);

        let gradients1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let updates1 = vec![Array1::from_vec(vec![0.1, 0.2])];

        let gradients2 = vec![Array1::from_vec(vec![0.8, 1.6])];
        let updates2 = vec![Array1::from_vec(vec![0.08, 0.16])];

        analyzer.record_step(&gradients1, &updates1).unwrap();
        analyzer.record_step(&gradients2, &updates2).unwrap();

        assert_eq!(analyzer.step_count(), 2);

        let stats = analyzer.get_stats();
        assert_eq!(stats.step_count, 2);
        assert_eq!(stats.per_group_stats.len(), 1);

        let expected_mag1 = (1.0_f64 * 1.0 + 2.0 * 2.0).sqrt();
        let expected_mag2 = (0.8_f64 * 0.8 + 1.6 * 1.6).sqrt();

        assert_relative_eq!(
            stats.per_group_stats[0].magnitude_history[0],
            expected_mag1,
            epsilon = 1e-6
        );
        assert_relative_eq!(
            stats.per_group_stats[0].magnitude_history[1],
            expected_mag2,
            epsilon = 1e-6
        );

        assert!(stats.mean_direction_similarity > 0.9);
    }

    #[test]
    #[ignore = "timeout"]
    fn test_benchmark_quadratic() {
        let mut benchmark = OptimizerBenchmark::new();
        benchmark.add_standard_test_functions();

        let learning_rate = 0.01;
        let mut step_function = |x: &Array1<f64>, grad: &Array1<f64>| x - &(grad * learning_rate);

        let results = benchmark
            .run_benchmark(
                "GradientDescent".to_string(),
                &mut step_function,
                1000,
                1e-6,
            )
            .unwrap();

        assert!(!results.is_empty());

        let quadratic_result = results
            .iter()
            .find(|r| r.function_name == "Quadratic")
            .unwrap();

        assert!(quadratic_result.converged);
        assert!(quadratic_result.final_function_value < 1e-3);
    }

    #[test]
    fn test_cosine_similarity() {
        let analyzer = GradientFlowAnalyzer::<f64, scirs2_core::ndarray::Ix1>::new(10);

        let arrays1 = vec![Array1::from_vec(vec![1.0, 0.0])];
        let arrays2 = vec![Array1::from_vec(vec![1.0, 0.0])];
        let similarity = analyzer
            .calculate_cosine_similarity(&arrays1, &arrays2)
            .unwrap();
        assert_relative_eq!(similarity, 1.0, epsilon = 1e-6);

        let arrays3 = vec![Array1::from_vec(vec![-1.0, 0.0])];
        let similarity2 = analyzer
            .calculate_cosine_similarity(&arrays1, &arrays3)
            .unwrap();
        assert_relative_eq!(similarity2, -1.0, epsilon = 1e-6);

        let arrays4 = vec![Array1::from_vec(vec![0.0, 1.0])];
        let similarity3 = analyzer
            .calculate_cosine_similarity(&arrays1, &arrays4)
            .unwrap();
        assert_relative_eq!(similarity3, 0.0, epsilon = 1e-6);
    }

    #[test]
    #[ignore = "timeout"]
    fn test_benchmark_report() {
        let mut benchmark = OptimizerBenchmark::new();
        benchmark.add_test_function(TestFunction {
            name: "Simple".to_string(),
            dimension: 2,
            function: Box::new(|x: &Array1<f64>| x[0] * x[0] + x[1] * x[1]),
            gradient: Box::new(|x: &Array1<f64>| Array1::from_vec(vec![2.0 * x[0], 2.0 * x[1]])),
            optimal_value: Some(0.0),
            optimal_point: Some(Array1::zeros(2)),
        });

        let mut step1 = |x: &Array1<f64>, grad: &Array1<f64>| x - &(grad * 0.1);
        let mut step2 = |x: &Array1<f64>, grad: &Array1<f64>| x - &(grad * 0.05);

        benchmark
            .run_benchmark("Fast".to_string(), &mut step1, 100, 1e-3)
            .unwrap();
        benchmark
            .run_benchmark("Slow".to_string(), &mut step2, 100, 1e-3)
            .unwrap();

        let report = benchmark.generate_report();
        assert_eq!(report.total_tests, 2);
        assert!(report.optimizer_performance.contains_key("Fast"));
        assert!(report.optimizer_performance.contains_key("Slow"));

        let comparison = report.compare_optimizers("Fast", "Slow").unwrap();
        assert_eq!(comparison.optimizer1, "Fast");
        assert_eq!(comparison.optimizer2, "Slow");
    }

    #[test]
    fn test_visualization_data_export() {
        let mut analyzer = GradientFlowAnalyzer::new(10);

        let _gradients = [Array1::from_vec(vec![1.0, 2.0])];
        let _updates = [Array1::from_vec(vec![0.1, 0.2])];

        for i in 0..5 {
            let scale = 1.0 / (i + 1) as f64;
            let scaled_grad = vec![Array1::from_vec(vec![scale, 2.0 * scale])];
            let scaled_update = vec![Array1::from_vec(vec![0.1 * scale, 0.2 * scale])];
            analyzer.record_step(&scaled_grad, &scaled_update).unwrap();
        }

        let viz_data = analyzer.export_for_visualization();
        assert_eq!(viz_data.step_indices.len(), 5);
        assert_eq!(viz_data.magnitude_series.len(), 1);
        assert_eq!(viz_data.magnitude_series[0].len(), 5);
        assert_eq!(viz_data.direction_similarities.len(), 5);

        let magnitudes = &viz_data.magnitude_series[0];
        for i in 1..magnitudes.len() {
            assert!(magnitudes[i] < magnitudes[i - 1]);
        }
    }

    #[test]
    fn test_convergence_analysis() {
        let mut analyzer = GradientFlowAnalyzer::new(10);

        for i in 0..10 {
            let scale = 1.0 / (i + 1) as f64;
            let gradients = vec![Array1::from_vec(vec![scale, scale])];
            let updates = vec![Array1::from_vec(vec![0.1 * scale, 0.1 * scale])];
            analyzer.record_step(&gradients, &updates).unwrap();
        }

        let stats = analyzer.get_stats();
        assert!(stats.is_converging);
        assert!(stats.stability_score > 0.5);
    }

    #[test]
    fn test_oscillation_detection() {
        let mut analyzer = GradientFlowAnalyzer::new(10);

        let gradients = vec![Array1::from_vec(vec![1.0, 1.0])];
        let updates = vec![Array1::from_vec(vec![0.1, 0.1])];
        analyzer.record_step(&gradients, &updates).unwrap();

        for i in 1..8 {
            let sign = if i % 2 == 0 { 1.0 } else { -1.0 };
            let gradients = vec![Array1::from_vec(vec![sign, sign])];
            let updates = vec![Array1::from_vec(vec![0.1 * sign, 0.1 * sign])];
            analyzer.record_step(&gradients, &updates).unwrap();
        }

        let stats = analyzer.get_stats();
        assert!(stats.oscillation_frequency >= 0.0);
    }

    #[test]
    fn test_optimizer_state_visualizer() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(100);

        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let state = visualization::OptimizerStateSnapshot::new()
            .with_custom_field("test_field".to_string(), 0.5);

        visualizer.record_step(&params, state, 0.01, 1.5);
        visualizer.record_step(
            &params,
            visualization::OptimizerStateSnapshot::new(),
            0.009,
            1.2,
        );

        assert_eq!(visualizer.step_count(), 2);

        let summary = visualizer.generate_state_summary();
        assert!(summary.contains("Total Steps: 2"));
        assert!(summary.contains("Current Loss: 1.200000"));

        let plot = visualizer.generate_convergence_plot(40, 10);
        assert!(plot.contains("Loss Convergence"));
        assert!(plot.contains("Steps: 2"));

        let lr_plot = visualizer.generate_learning_rate_plot(40, 10);
        assert!(lr_plot.contains("Learning Rate Schedule"));

        let heatmap = visualizer.generate_parameter_heatmap(20, 5);
        assert!(heatmap.contains("Parameter Evolution Heatmap"));
    }

    #[test]
    fn test_visualization_export() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        for i in 0..5 {
            let params = vec![Array1::from_vec(vec![i as f64, (i * 2) as f64])];
            let state = visualization::OptimizerStateSnapshot::new();
            let lr = 0.01 / (i + 1) as f64;
            let loss = 1.0 / (i + 1) as f64;

            visualizer.record_step(&params, state, lr, loss);
        }

        let export = visualizer.export_data();
        assert_eq!(export.step_indices.len(), 5);
        assert_eq!(export.loss_history.len(), 5);
        assert_eq!(export.learning_rate_history.len(), 5);
        assert_eq!(export.parameter_norms.len(), 5);
        assert_eq!(export.state_snapshots.len(), 5);

        assert!(export.loss_history[0] > export.loss_history[4]);
        assert!(export.learning_rate_history[0] > export.learning_rate_history[4]);
    }

    #[test]
    fn test_optimizer_dashboard() {
        let mut dashboard = visualization::OptimizerDashboard::new();

        dashboard.add_optimizer("SGD".to_string(), 100);
        dashboard.add_optimizer("Adam".to_string(), 100);

        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let state = visualization::OptimizerStateSnapshot::new();

        dashboard
            .record_optimizer_step("SGD", &params, state.clone(), 0.01, 1.0)
            .unwrap();
        dashboard
            .record_optimizer_step("Adam", &params, state, 0.001, 0.8)
            .unwrap();

        let optimizers = dashboard.list_optimizers();
        assert_eq!(optimizers.len(), 2);
        assert!(optimizers.contains(&&"SGD".to_string()));
        assert!(optimizers.contains(&&"Adam".to_string()));

        let sgd_viz = dashboard.get_visualizer("SGD").unwrap();
        assert_eq!(sgd_viz.step_count(), 1);

        let report = dashboard.generate_comparison_report();
        assert!(report.contains("Optimizer Comparison Dashboard"));
        assert!(report.contains("SGD"));
        assert!(report.contains("Adam"));
        assert!(report.contains("Lowest Current Loss: Adam"));
    }

    #[test]
    fn test_state_snapshot_custom_fields() {
        let snapshot = visualization::OptimizerStateSnapshot::new()
            .with_custom_field("custom1".to_string(), 1.5)
            .with_custom_field("custom2".to_string(), 2.5);

        assert_eq!(snapshot.custom_fields.len(), 2);
        assert_eq!(snapshot.custom_fields.get("custom1"), Some(&1.5));
        assert_eq!(snapshot.custom_fields.get("custom2"), Some(&2.5));
    }

    #[test]
    fn test_visualizer_clear() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        let params = vec![Array1::from_vec(vec![1.0])];
        let state = visualization::OptimizerStateSnapshot::new();

        visualizer.record_step(&params, state, 0.01, 1.0);
        assert_eq!(visualizer.step_count(), 1);

        visualizer.clear();
        assert_eq!(visualizer.step_count(), 0);

        let summary = visualizer.generate_state_summary();
        assert!(summary.contains("Total Steps: 0"));
    }

    #[test]
    fn test_dashboard_invalid_optimizer() {
        let mut dashboard = visualization::OptimizerDashboard::new();

        let params = vec![Array1::from_vec(vec![1.0])];
        let state = visualization::OptimizerStateSnapshot::new();

        let result = dashboard.record_optimizer_step("NonExistent", &params, state, 0.01, 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_ascii_plot_generation() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        for i in 0..10 {
            let params = vec![Array1::from_vec(vec![1.0])];
            let state = visualization::OptimizerStateSnapshot::new();
            let loss = 10.0 - i as f64;
            let lr = 0.1;
            visualizer.record_step(&params, state, lr, loss);
        }

        let plot = visualizer.generate_convergence_plot(20, 5);
        let lines: Vec<&str> = plot.lines().collect();
        assert!(lines.len() > 5);
        assert!(plot.contains("|"));
        assert!(plot.contains("-"));
        assert!(plot.contains("*"));

        let lr_plot = visualizer.generate_learning_rate_plot(20, 5);
        assert!(lr_plot.contains("Learning Rate Schedule"));
        assert!(lr_plot.contains("Max: 0.100000, Min: 0.100000"));
    }

    #[test]
    fn test_parameter_heatmap_generation() {
        let mut visualizer = visualization::OptimizerStateVisualizer::new(10);

        for i in 0..5 {
            let params = vec![Array1::from_vec(vec![i as f64 * 0.1, i as f64 * 0.2])];
            let state = visualization::OptimizerStateSnapshot::new();
            visualizer.record_step(&params, state, 0.01, 1.0);
        }

        let heatmap = visualizer.generate_parameter_heatmap(10, 5);
        assert!(heatmap.contains("Parameter Evolution Heatmap"));
        assert!(heatmap.contains("Legend"));
        assert!(heatmap.contains("Range"));

        assert!(heatmap.contains("P"));
    }
}