//! Runtime metrics for optimizers: per-step timing, gradient and parameter
//! statistics, and lightweight convergence tracking.

use scirs2_core::ndarray::{ArrayView1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::time::{Duration, Instant};

use crate::error::Result;

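/// Snapshot of runtime metrics for a single optimizer: step counts, timing,
/// gradient/parameter statistics, and convergence state.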
#[derive(Debug, Clone)]
pub struct OptimizerMetrics {
    /// Optimizer name used as the registry key.
    pub name: String,
    /// Number of recorded optimization steps.
    pub step_count: u64,
    /// Total wall-clock time spent in recorded steps.
    pub total_step_time: Duration,
    /// Mean wall-clock time per step.
    pub avg_step_time: Duration,
    /// Learning rate reported at the most recent step.
    pub current_learning_rate: f64,
    /// Statistics over the most recent gradient vector.
    pub gradient_stats: GradientStatistics,
    /// Statistics over the most recent parameter update.
    pub parameter_stats: ParameterStatistics,
    /// Convergence estimates derived from parameter updates.
    pub convergence: ConvergenceMetrics,
    /// Estimated memory usage in bytes (set externally; not updated here).
    pub memory_usage: usize,
}

impl OptimizerMetrics {
    /// Creates a zeroed metrics record for the named optimizer.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            step_count: 0,
            total_step_time: Duration::ZERO,
            avg_step_time: Duration::ZERO,
            current_learning_rate: 0.0,
            gradient_stats: GradientStatistics::default(),
            parameter_stats: ParameterStatistics::default(),
            convergence: ConvergenceMetrics::default(),
            memory_usage: 0,
        }
    }

    /// Records one optimization step: timing, learning rate, and derived
    /// gradient/parameter/convergence statistics.
    pub fn update_step<A: Float>(
        &mut self,
        step_duration: Duration,
        learning_rate: f64,
        gradients: &ArrayView1<A>,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) {
        self.step_count += 1;
        self.total_step_time += step_duration;
        self.avg_step_time = self.total_step_time / self.step_count as u32;
        self.current_learning_rate = learning_rate;

        // Refresh derived statistics from this step.
        self.gradient_stats.update(gradients);
        self.parameter_stats.update(params_before, params_after);
        self.convergence.update(&self.parameter_stats);
    }

    /// Steps per second over the total recorded step time (0.0 before any steps).
    pub fn throughput(&self) -> f64 {
        if self.total_step_time.as_secs_f64() > 0.0 {
            self.step_count as f64 / self.total_step_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Clears counters and statistics; the name, current learning rate, and
    /// memory usage are left untouched.
    pub fn reset(&mut self) {
        self.step_count = 0;
        self.total_step_time = Duration::ZERO;
        self.avg_step_time = Duration::ZERO;
        self.gradient_stats = GradientStatistics::default();
        self.parameter_stats = ParameterStatistics::default();
        self.convergence = ConvergenceMetrics::default();
    }
}

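/// Summary statistics computed from a single gradient vector.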
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
    /// Mean of the gradient elements.
    pub mean: f64,
    /// Standard deviation of the gradient elements (population).
    pub std_dev: f64,
    /// Largest gradient element.
    pub max: f64,
    /// Smallest gradient element.
    pub min: f64,
    /// L2 norm of the gradient vector.
    pub norm: f64,
    /// Count of elements with |g| < 1e-10, treated as zero.
    pub num_zeros: usize,
}

impl GradientStatistics {
    /// Recomputes all statistics from the given gradient vector.
    /// Empty input leaves the previous values unchanged.
    pub fn update<A: Float>(&mut self, gradients: &ArrayView1<A>) {
        let n = gradients.len();
        if n == 0 {
            return;
        }

        let sum: f64 = gradients.iter().map(|&g| g.to_f64().unwrap()).sum();
        self.mean = sum / n as f64;

        // Population variance around the freshly computed mean.
        let variance: f64 = gradients
            .iter()
            .map(|&g| {
                let diff = g.to_f64().unwrap() - self.mean;
                diff * diff
            })
            .sum::<f64>()
            / n as f64;
        self.std_dev = variance.sqrt();

        self.max = gradients
            .iter()
            .map(|&g| g.to_f64().unwrap())
            .fold(f64::NEG_INFINITY, f64::max);
        self.min = gradients
            .iter()
            .map(|&g| g.to_f64().unwrap())
            .fold(f64::INFINITY, f64::min);

        // L2 norm: square root of the sum of squared elements.
        self.norm = gradients
            .iter()
            .map(|&g| {
                let val = g.to_f64().unwrap();
                val * val
            })
            .sum::<f64>()
            .sqrt();

        // Near-zero elements, useful for spotting dead or sparse gradients.
        self.num_zeros = gradients
            .iter()
            .filter(|&&g| g.to_f64().unwrap().abs() < 1e-10)
            .count();
    }
}

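/// Summary statistics for a parameter vector and its most recent update.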
#[derive(Debug, Clone, Default)]
pub struct ParameterStatistics {
    /// Mean of the updated parameter values.
    pub mean: f64,
    /// Standard deviation of the updated parameter values (population).
    pub std_dev: f64,
    /// L2 norm of the update (params_after - params_before).
    pub update_magnitude: f64,
    /// Update magnitude relative to the norm of the previous parameters.
    pub relative_change: f64,
}

impl ParameterStatistics {
    /// Recomputes statistics from the parameters before and after an update.
    /// Empty input leaves the previous values unchanged.
    pub fn update<A: Float>(
        &mut self,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) {
        let n = params_after.len();
        if n == 0 {
            return;
        }

        let sum: f64 = params_after.iter().map(|&p| p.to_f64().unwrap()).sum();
        self.mean = sum / n as f64;

        let variance: f64 = params_after
            .iter()
            .map(|&p| {
                let diff = p.to_f64().unwrap() - self.mean;
                diff * diff
            })
            .sum::<f64>()
            / n as f64;
        self.std_dev = variance.sqrt();

        // L2 norm of the elementwise parameter change.
        self.update_magnitude = params_before
            .iter()
            .zip(params_after.iter())
            .map(|(&before, &after)| {
                let diff = after.to_f64().unwrap() - before.to_f64().unwrap();
                diff * diff
            })
            .sum::<f64>()
            .sqrt();

        let params_norm: f64 = params_before
            .iter()
            .map(|&p| {
                let val = p.to_f64().unwrap();
                val * val
            })
            .sum::<f64>()
            .sqrt();

        // Guard against division by a (near-)zero parameter norm.
        self.relative_change = if params_norm > 1e-10 {
            self.update_magnitude / params_norm
        } else {
            0.0
        };
    }
}

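/// Convergence estimates derived from the trend of parameter-update magnitudes.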
#[derive(Debug, Clone, Default)]
pub struct ConvergenceMetrics {
    /// Exponential moving average of update magnitudes.
    pub update_moving_avg: f64,
    /// True when the latest update is smaller than the moving average.
    pub is_converging: bool,
    /// Rough estimate of remaining steps to convergence, when available.
    pub estimated_steps_to_convergence: Option<u64>,
    /// Fractional shrinkage of the update magnitude versus the moving average.
    pub convergence_rate: f64,
}

impl ConvergenceMetrics {
    /// Updates the convergence estimate from the latest parameter statistics.
    /// The moving average uses an exponential weight of alpha = 0.1.
    pub fn update(&mut self, param_stats: &ParameterStatistics) {
        // Compare against the previous average before folding in the new value.
        if self.update_moving_avg > 1e-10 {
            self.is_converging = param_stats.update_magnitude < self.update_moving_avg;
            self.convergence_rate = 1.0 - (param_stats.update_magnitude / self.update_moving_avg);
        }

        let alpha = 0.1;
        self.update_moving_avg =
            alpha * param_stats.update_magnitude + (1.0 - alpha) * self.update_moving_avg;
    }
}

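/// Collects and aggregates metrics for multiple optimizers, keyed by name.
///
/// A minimal usage sketch (illustrative values; see the tests below for
/// runnable examples):
///
/// ```ignore
/// let mut collector = MetricsCollector::new();
/// collector.register_optimizer("sgd");
/// // After each optimizer step:
/// // collector.update("sgd", step_duration, lr, &grads.view(),
/// //                  &before.view(), &after.view())?;
/// println!("{}", collector.summary_report());
/// ```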
pub struct MetricsCollector {
    /// Per-optimizer metrics, keyed by registered name.
    metrics: HashMap<String, OptimizerMetrics>,
    /// Time at which collection (last) started.
    start_time: Instant,
}

impl MetricsCollector {
    /// Creates an empty collector and starts the elapsed-time clock.
    pub fn new() -> Self {
        Self {
            metrics: HashMap::new(),
            start_time: Instant::now(),
        }
    }

    /// Registers an optimizer for tracking; re-registering is a no-op.
    pub fn register_optimizer(&mut self, name: impl Into<String>) {
        let name = name.into();
        self.metrics
            .entry(name.clone())
            .or_insert_with(|| OptimizerMetrics::new(name));
    }

    /// Records one step for a registered optimizer, or returns an
    /// `InvalidConfig` error if the name is unknown.
    pub fn update<A: Float + ScalarOperand>(
        &mut self,
        optimizer_name: &str,
        step_duration: Duration,
        learning_rate: f64,
        gradients: &ArrayView1<A>,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) -> Result<()> {
        if let Some(metrics) = self.metrics.get_mut(optimizer_name) {
            metrics.update_step(
                step_duration,
                learning_rate,
                gradients,
                params_before,
                params_after,
            );
            Ok(())
        } else {
            Err(crate::error::OptimError::InvalidConfig(format!(
                "Optimizer '{}' not registered",
                optimizer_name
            )))
        }
    }

    /// Returns the metrics for a registered optimizer, if any.
    pub fn get_metrics(&self, optimizer_name: &str) -> Option<&OptimizerMetrics> {
        self.metrics.get(optimizer_name)
    }

    /// Returns all tracked metrics, keyed by optimizer name.
    pub fn all_metrics(&self) -> &HashMap<String, OptimizerMetrics> {
        &self.metrics
    }

    /// Wall-clock time since the collector was created or last reset.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Resets every tracked optimizer's metrics and restarts the clock.
    pub fn reset(&mut self) {
        for metrics in self.metrics.values_mut() {
            metrics.reset();
        }
        self.start_time = Instant::now();
    }

    /// Renders a human-readable, multi-line summary of all tracked optimizers.
    pub fn summary_report(&self) -> String {
        let mut report = String::new();
        report.push_str("=== Optimizer Metrics Summary ===\n");
        report.push_str(&format!("Total elapsed time: {:?}\n\n", self.elapsed()));

        for (name, metrics) in &self.metrics {
            report.push_str(&format!("Optimizer: {}\n", name));
            report.push_str(&format!("  Steps: {}\n", metrics.step_count));
            report.push_str(&format!("  Avg step time: {:?}\n", metrics.avg_step_time));
            report.push_str(&format!(
                "  Throughput: {:.2} steps/sec\n",
                metrics.throughput()
            ));
            report.push_str(&format!(
                "  Learning rate: {:.6}\n",
                metrics.current_learning_rate
            ));
            report.push_str(&format!(
                "  Gradient norm: {:.6}\n",
                metrics.gradient_stats.norm
            ));
            report.push_str(&format!(
                "  Update magnitude: {:.6}\n",
                metrics.parameter_stats.update_magnitude
            ));
            report.push_str(&format!(
                "  Converging: {}\n",
                metrics.convergence.is_converging
            ));
            report.push_str(&format!(
                "  Memory usage: {} bytes\n\n",
                metrics.memory_usage
            ));
        }

        report
    }
}

impl Default for MetricsCollector {
    fn default() -> Self {
        Self::new()
    }
}

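/// Stateless helpers that render `OptimizerMetrics` as JSON or CSV.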
pub struct MetricsReporter;

impl MetricsReporter {
    /// Serializes the key metrics as a JSON object. Note: the optimizer name
    /// is interpolated verbatim, so names containing quotes are not escaped.
    pub fn to_json(metrics: &OptimizerMetrics) -> String {
        format!(
            r#"{{
  "name": "{}",
  "step_count": {},
  "avg_step_time_ms": {},
  "throughput": {},
  "learning_rate": {},
  "gradient_norm": {},
  "update_magnitude": {},
  "is_converging": {}
}}"#,
            metrics.name,
            metrics.step_count,
            metrics.avg_step_time.as_millis(),
            metrics.throughput(),
            metrics.current_learning_rate,
            metrics.gradient_stats.norm,
            metrics.parameter_stats.update_magnitude,
            metrics.convergence.is_converging
        )
    }

    /// Returns the CSV header matching [`MetricsReporter::to_csv`].
    pub fn to_csv_header() -> String {
        "name,step_count,avg_step_time_ms,throughput,learning_rate,gradient_norm,update_magnitude,is_converging".to_string()
    }

    /// Serializes the key metrics as one CSV row.
    pub fn to_csv(metrics: &OptimizerMetrics) -> String {
        format!(
            "{},{},{},{},{},{},{},{}",
            metrics.name,
            metrics.step_count,
            metrics.avg_step_time.as_millis(),
            metrics.throughput(),
            metrics.current_learning_rate,
            metrics.gradient_stats.norm,
            metrics.parameter_stats.update_magnitude,
            metrics.convergence.is_converging
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_optimizer_metrics_creation() {
        let metrics = OptimizerMetrics::new("sgd");
        assert_eq!(metrics.name, "sgd");
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.throughput(), 0.0);
    }

    #[test]
    fn test_gradient_statistics() {
        let mut stats = GradientStatistics::default();
        let grads = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        stats.update(&grads.view());

        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert!(stats.max > 4.9);
        assert!(stats.min < 1.1);
        assert!(stats.norm > 0.0);
    }

    #[test]
    fn test_parameter_statistics() {
        let mut stats = ParameterStatistics::default();
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.9, 1.9, 2.9]);
        stats.update(&before.view(), &after.view());

        assert!(stats.update_magnitude > 0.0);
        assert!(stats.relative_change > 0.0);
        assert!((stats.mean - 1.9).abs() < 1e-6);
    }

    #[test]
    fn test_metrics_collector() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.99, 1.98, 2.97]);

        let result = collector.update(
            "sgd",
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert!(result.is_ok());
        let metrics = collector.get_metrics("sgd").unwrap();
        assert_eq!(metrics.step_count, 1);
    }

    #[test]
    fn test_metrics_collector_multiple_updates() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("adam");

        let grads = Array1::from_vec(vec![0.1, 0.2]);
        let before = Array1::from_vec(vec![1.0, 2.0]);
        let after = Array1::from_vec(vec![0.99, 1.98]);

        for _ in 0..10 {
            collector
                .update(
                    "adam",
                    Duration::from_millis(5),
                    0.001,
                    &grads.view(),
                    &before.view(),
                    &after.view(),
                )
                .unwrap();
        }

        let metrics = collector.get_metrics("adam").unwrap();
        assert_eq!(metrics.step_count, 10);
        assert!(metrics.throughput() > 0.0);
    }

    #[test]
    fn test_metrics_reset() {
        let mut metrics = OptimizerMetrics::new("test");
        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        metrics.update_step(
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert_eq!(metrics.step_count, 1);

        metrics.reset();
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.total_step_time, Duration::ZERO);
    }

    #[test]
    fn test_summary_report() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        collector
            .update(
                "sgd",
                Duration::from_millis(10),
                0.01,
                &grads.view(),
                &before.view(),
                &after.view(),
            )
            .unwrap();

        let report = collector.summary_report();
        assert!(report.contains("Optimizer: sgd"));
        assert!(report.contains("Steps: 1"));
    }

    #[test]
    fn test_metrics_reporter_json() {
        let metrics = OptimizerMetrics::new("test");
        let json = MetricsReporter::to_json(&metrics);
        assert!(json.contains("\"name\": \"test\""));
        assert!(json.contains("\"step_count\": 0"));
    }

    #[test]
    fn test_metrics_reporter_csv() {
        let metrics = OptimizerMetrics::new("test");
        let header = MetricsReporter::to_csv_header();
        let row = MetricsReporter::to_csv(&metrics);

        assert!(header.contains("name"));
        assert!(header.contains("step_count"));
        assert!(row.starts_with("test,0,"));
    }

    #[test]
    fn test_convergence_metrics() {
        let mut convergence = ConvergenceMetrics::default();

        let mut param_stats = ParameterStatistics {
            update_magnitude: 1.0,
            ..Default::default()
        };
        // First update: EMA = 0.1 * 1.0 + 0.9 * 0.0 = 0.1.
        convergence.update(&param_stats);
        assert_eq!(convergence.update_moving_avg, 0.1);

        // Second update: EMA = 0.1 * 0.5 + 0.9 * 0.1 = 0.14.
        param_stats.update_magnitude = 0.5;
        convergence.update(&param_stats);
        assert!((convergence.update_moving_avg - 0.14).abs() < 1e-6);

        // Third update: 0.05 < 0.14, so the trend counts as converging.
        param_stats.update_magnitude = 0.05;
        convergence.update(&param_stats);
        assert!(convergence.is_converging);
        assert!(convergence.update_moving_avg > 0.0);
    }
}