//! Optimizer metrics collection and reporting.
//!
//! Tracks per-step timing, gradient and parameter statistics, and simple
//! convergence indicators for one or more optimizers.

use scirs2_core::ndarray::{ArrayView1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::time::{Duration, Instant};

use crate::error::Result;

/// Per-optimizer runtime metrics collected across training steps.
#[derive(Debug, Clone)]
pub struct OptimizerMetrics {
    /// Name of the optimizer being tracked.
    pub name: String,
    /// Number of steps recorded so far.
    pub step_count: u64,
    /// Total wall-clock time spent in recorded steps.
    pub total_step_time: Duration,
    /// Average wall-clock time per step.
    pub avg_step_time: Duration,
    /// Learning rate reported at the most recent step.
    pub current_learning_rate: f64,
    /// Statistics over the most recent gradient vector.
    pub gradient_stats: GradientStatistics,
    /// Statistics over the most recent parameter update.
    pub parameter_stats: ParameterStatistics,
    /// Convergence indicators derived from parameter updates.
    pub convergence: ConvergenceMetrics,
    /// Estimated memory usage of optimizer state, in bytes.
    pub memory_usage: usize,
}

impl OptimizerMetrics {
    /// Creates an empty metrics record for the named optimizer.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            step_count: 0,
            total_step_time: Duration::ZERO,
            avg_step_time: Duration::ZERO,
            current_learning_rate: 0.0,
            gradient_stats: GradientStatistics::default(),
            parameter_stats: ParameterStatistics::default(),
            convergence: ConvergenceMetrics::default(),
            memory_usage: 0,
        }
    }

    /// Records one optimizer step and refreshes all derived statistics.
    pub fn update_step<A: Float>(
        &mut self,
        step_duration: Duration,
        learning_rate: f64,
        gradients: &ArrayView1<A>,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) {
        self.step_count += 1;
        self.total_step_time += step_duration;
        self.avg_step_time = self.total_step_time / self.step_count as u32;
        self.current_learning_rate = learning_rate;

        // Refresh derived statistics from this step.
        self.gradient_stats.update(gradients);
        self.parameter_stats.update(params_before, params_after);
        self.convergence.update(&self.parameter_stats);
    }

    /// Returns throughput in steps per second, or 0.0 before the first step.
    pub fn throughput(&self) -> f64 {
        if self.total_step_time.as_secs_f64() > 0.0 {
            self.step_count as f64 / self.total_step_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Clears all counters and statistics, keeping the optimizer name.
    pub fn reset(&mut self) {
        self.step_count = 0;
        self.total_step_time = Duration::ZERO;
        self.avg_step_time = Duration::ZERO;
        self.gradient_stats = GradientStatistics::default();
        self.parameter_stats = ParameterStatistics::default();
        self.convergence = ConvergenceMetrics::default();
    }
}

/// Summary statistics over a gradient vector.
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
    /// Mean of the gradient components.
    pub mean: f64,
    /// Standard deviation of the gradient components.
    pub std_dev: f64,
    /// Largest gradient component.
    pub max: f64,
    /// Smallest gradient component.
    pub min: f64,
    /// L2 norm of the gradient vector.
    pub norm: f64,
    /// Number of components with magnitude below 1e-10.
    pub num_zeros: usize,
}

impl GradientStatistics {
    /// Recomputes all statistics from the given gradient vector.
    pub fn update<A: Float>(&mut self, gradients: &ArrayView1<A>) {
        let n = gradients.len();
        if n == 0 {
            return;
        }

        let to_f64 = |g: A| g.to_f64().expect("gradient not representable as f64");

        let sum: f64 = gradients.iter().map(|&g| to_f64(g)).sum();
        self.mean = sum / n as f64;

        // Population variance around the freshly computed mean.
        let variance: f64 = gradients
            .iter()
            .map(|&g| {
                let diff = to_f64(g) - self.mean;
                diff * diff
            })
            .sum::<f64>()
            / n as f64;
        self.std_dev = variance.sqrt();

        self.max = gradients
            .iter()
            .map(|&g| to_f64(g))
            .fold(f64::NEG_INFINITY, f64::max);
        self.min = gradients
            .iter()
            .map(|&g| to_f64(g))
            .fold(f64::INFINITY, f64::min);

        // L2 norm of the gradient vector.
        self.norm = gradients
            .iter()
            .map(|&g| {
                let val = to_f64(g);
                val * val
            })
            .sum::<f64>()
            .sqrt();

        // Components below this threshold are treated as numerically zero.
        self.num_zeros = gradients
            .iter()
            .filter(|&&g| to_f64(g).abs() < 1e-10)
            .count();
    }
}

/// Summary statistics over a parameter vector and its most recent update.
#[derive(Debug, Clone, Default)]
pub struct ParameterStatistics {
    /// Mean of the updated parameter values.
    pub mean: f64,
    /// Standard deviation of the updated parameter values.
    pub std_dev: f64,
    /// L2 norm of the parameter update (after minus before).
    pub update_magnitude: f64,
    /// Update magnitude relative to the L2 norm of the previous parameters.
    pub relative_change: f64,
}

impl ParameterStatistics {
    /// Recomputes all statistics from the parameters before and after a step.
    pub fn update<A: Float>(
        &mut self,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) {
        let n = params_after.len();
        if n == 0 {
            return;
        }

        let to_f64 = |p: A| p.to_f64().expect("parameter not representable as f64");

        let sum: f64 = params_after.iter().map(|&p| to_f64(p)).sum();
        self.mean = sum / n as f64;

        let variance: f64 = params_after
            .iter()
            .map(|&p| {
                let diff = to_f64(p) - self.mean;
                diff * diff
            })
            .sum::<f64>()
            / n as f64;
        self.std_dev = variance.sqrt();

        // L2 norm of the element-wise update (after - before).
        self.update_magnitude = params_before
            .iter()
            .zip(params_after.iter())
            .map(|(&before, &after)| {
                let diff = to_f64(after) - to_f64(before);
                diff * diff
            })
            .sum::<f64>()
            .sqrt();

        // Normalize by the previous parameter norm, guarding against a
        // numerically zero denominator.
        let params_norm: f64 = params_before
            .iter()
            .map(|&p| {
                let val = to_f64(p);
                val * val
            })
            .sum::<f64>()
            .sqrt();

        self.relative_change = if params_norm > 1e-10 {
            self.update_magnitude / params_norm
        } else {
            0.0
        };
    }
}

/// Lightweight convergence indicators derived from parameter updates.
#[derive(Debug, Clone, Default)]
pub struct ConvergenceMetrics {
    /// Exponential moving average of the update magnitude.
    pub update_moving_avg: f64,
    /// Whether the latest update was smaller than the moving average.
    pub is_converging: bool,
    /// Rough estimate of remaining steps to convergence, if available.
    pub estimated_steps_to_convergence: Option<u64>,
    /// Fractional shrinkage of the update magnitude versus the moving average.
    pub convergence_rate: f64,
}

impl ConvergenceMetrics {
    /// Updates the convergence indicators from the latest parameter statistics.
    pub fn update(&mut self, param_stats: &ParameterStatistics) {
        // Compare against the moving average from previous steps before
        // folding in the current update.
        if self.update_moving_avg > 1e-10 {
            self.is_converging = param_stats.update_magnitude < self.update_moving_avg;
            self.convergence_rate = 1.0 - (param_stats.update_magnitude / self.update_moving_avg);
        }

        // Exponential moving average with smoothing factor alpha:
        // avg <- alpha * magnitude + (1 - alpha) * avg
        let alpha = 0.1;
        self.update_moving_avg =
            alpha * param_stats.update_magnitude + (1.0 - alpha) * self.update_moving_avg;
    }
}

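/// Central collection point for [`OptimizerMetrics`] across any number of
/// registered optimizers.
///
/// A minimal usage sketch (marked `ignore` because the enclosing crate name
/// is not shown in this module; all numeric values are arbitrary):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
/// use std::time::Duration;
///
/// let mut collector = MetricsCollector::new();
/// collector.register_optimizer("sgd");
///
/// let grads = Array1::from_vec(vec![0.1, 0.2]);
/// let before = Array1::from_vec(vec![1.0, 2.0]);
/// let after = Array1::from_vec(vec![0.99, 1.98]);
/// collector.update(
///     "sgd",
///     Duration::from_millis(5),
///     0.01,
///     &grads.view(),
///     &before.view(),
///     &after.view(),
/// )?;
///
/// println!("{}", collector.summary_report());
/// ```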
pub struct MetricsCollector {
    /// Metrics keyed by optimizer name.
    metrics: HashMap<String, OptimizerMetrics>,
    /// Time when collection started (or was last reset).
    start_time: Instant,
}

impl MetricsCollector {
    /// Creates an empty collector and starts the elapsed-time clock.
    pub fn new() -> Self {
        Self {
            metrics: HashMap::new(),
            start_time: Instant::now(),
        }
    }

    /// Registers an optimizer by name; re-registering is a no-op.
    pub fn register_optimizer(&mut self, name: impl Into<String>) {
        let name = name.into();
        self.metrics
            .entry(name.clone())
            .or_insert_with(|| OptimizerMetrics::new(name));
    }

    /// Records a step for a registered optimizer.
    ///
    /// Returns an error if the optimizer was never registered.
    pub fn update<A: Float + ScalarOperand>(
        &mut self,
        optimizer_name: &str,
        step_duration: Duration,
        learning_rate: f64,
        gradients: &ArrayView1<A>,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) -> Result<()> {
        if let Some(metrics) = self.metrics.get_mut(optimizer_name) {
            metrics.update_step(
                step_duration,
                learning_rate,
                gradients,
                params_before,
                params_after,
            );
            Ok(())
        } else {
            Err(crate::error::OptimError::InvalidConfig(format!(
                "Optimizer '{}' not registered",
                optimizer_name
            )))
        }
    }

    /// Returns the metrics for one optimizer, if registered.
    pub fn get_metrics(&self, optimizer_name: &str) -> Option<&OptimizerMetrics> {
        self.metrics.get(optimizer_name)
    }

    /// Returns all collected metrics keyed by optimizer name.
    pub fn all_metrics(&self) -> &HashMap<String, OptimizerMetrics> {
        &self.metrics
    }

    /// Returns time elapsed since creation or the last reset.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Resets all per-optimizer metrics and restarts the clock.
    pub fn reset(&mut self) {
        for metrics in self.metrics.values_mut() {
            metrics.reset();
        }
        self.start_time = Instant::now();
    }

    /// Renders a human-readable summary of all registered optimizers.
    pub fn summary_report(&self) -> String {
        let mut report = String::new();
        report.push_str("=== Optimizer Metrics Summary ===\n");
        report.push_str(&format!("Total elapsed time: {:?}\n\n", self.elapsed()));

        for (name, metrics) in &self.metrics {
            report.push_str(&format!("Optimizer: {}\n", name));
            report.push_str(&format!("  Steps: {}\n", metrics.step_count));
            report.push_str(&format!("  Avg step time: {:?}\n", metrics.avg_step_time));
            report.push_str(&format!(
                "  Throughput: {:.2} steps/sec\n",
                metrics.throughput()
            ));
            report.push_str(&format!(
                "  Learning rate: {:.6}\n",
                metrics.current_learning_rate
            ));
            report.push_str(&format!(
                "  Gradient norm: {:.6}\n",
                metrics.gradient_stats.norm
            ));
            report.push_str(&format!(
                "  Update magnitude: {:.6}\n",
                metrics.parameter_stats.update_magnitude
            ));
            report.push_str(&format!(
                "  Converging: {}\n",
                metrics.convergence.is_converging
            ));
            report.push_str(&format!(
                "  Memory usage: {} bytes\n\n",
                metrics.memory_usage
            ));
        }

        report
    }
}

impl Default for MetricsCollector {
    fn default() -> Self {
        Self::new()
    }
}

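/// Stateless formatter that renders [`OptimizerMetrics`] as JSON or CSV.
///
/// A minimal sketch of the intended flow (any populated metrics value works;
/// marked `ignore` for the same reason as the example above):
///
/// ```ignore
/// let metrics = OptimizerMetrics::new("sgd");
/// let json = MetricsReporter::to_json(&metrics);
/// let table = format!(
///     "{}\n{}",
///     MetricsReporter::to_csv_header(),
///     MetricsReporter::to_csv(&metrics),
/// );
/// ```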
pub struct MetricsReporter;

impl MetricsReporter {
    /// Renders a single metrics record as a JSON object.
    ///
    /// Note: `name` is interpolated verbatim, so names containing `"` or `\`
    /// would produce invalid JSON.
    pub fn to_json(metrics: &OptimizerMetrics) -> String {
        format!(
            r#"{{
    "name": "{}",
    "step_count": {},
    "avg_step_time_ms": {},
    "throughput": {},
    "learning_rate": {},
    "gradient_norm": {},
    "update_magnitude": {},
    "is_converging": {}
}}"#,
            metrics.name,
            metrics.step_count,
            metrics.avg_step_time.as_millis(),
            metrics.throughput(),
            metrics.current_learning_rate,
            metrics.gradient_stats.norm,
            metrics.parameter_stats.update_magnitude,
            metrics.convergence.is_converging
        )
    }

    /// Returns the header row matching [`MetricsReporter::to_csv`].
    pub fn to_csv_header() -> String {
        "name,step_count,avg_step_time_ms,throughput,learning_rate,gradient_norm,update_magnitude,is_converging".to_string()
    }

    /// Renders a single metrics record as one CSV row.
    pub fn to_csv(metrics: &OptimizerMetrics) -> String {
        format!(
            "{},{},{},{},{},{},{},{}",
            metrics.name,
            metrics.step_count,
            metrics.avg_step_time.as_millis(),
            metrics.throughput(),
            metrics.current_learning_rate,
            metrics.gradient_stats.norm,
            metrics.parameter_stats.update_magnitude,
            metrics.convergence.is_converging
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_optimizer_metrics_creation() {
        let metrics = OptimizerMetrics::new("sgd");
        assert_eq!(metrics.name, "sgd");
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.throughput(), 0.0);
    }

    #[test]
    fn test_gradient_statistics() {
        let mut stats = GradientStatistics::default();
        let grads = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        stats.update(&grads.view());

        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert!(stats.max > 4.9);
        assert!(stats.min < 1.1);
        assert!(stats.norm > 0.0);
    }

    #[test]
    fn test_parameter_statistics() {
        let mut stats = ParameterStatistics::default();
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.9, 1.9, 2.9]);
        stats.update(&before.view(), &after.view());

        assert!(stats.update_magnitude > 0.0);
        assert!(stats.relative_change > 0.0);
        assert!((stats.mean - 1.9).abs() < 1e-6);
    }

    #[test]
    fn test_metrics_collector() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.99, 1.98, 2.97]);

        let result = collector.update(
            "sgd",
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert!(result.is_ok());
        let metrics = collector
            .get_metrics("sgd")
            .expect("sgd should be registered");
        assert_eq!(metrics.step_count, 1);
    }

    #[test]
    fn test_metrics_collector_multiple_updates() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("adam");

        let grads = Array1::from_vec(vec![0.1, 0.2]);
        let before = Array1::from_vec(vec![1.0, 2.0]);
        let after = Array1::from_vec(vec![0.99, 1.98]);

        for _ in 0..10 {
            collector
                .update(
                    "adam",
                    Duration::from_millis(5),
                    0.001,
                    &grads.view(),
                    &before.view(),
                    &after.view(),
                )
                .expect("update should succeed for a registered optimizer");
        }

        let metrics = collector
            .get_metrics("adam")
            .expect("adam should be registered");
        assert_eq!(metrics.step_count, 10);
        assert!(metrics.throughput() > 0.0);
    }

    #[test]
    fn test_metrics_reset() {
        let mut metrics = OptimizerMetrics::new("test");
        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        metrics.update_step(
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert_eq!(metrics.step_count, 1);

        metrics.reset();
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.total_step_time, Duration::ZERO);
    }

    #[test]
    fn test_summary_report() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        collector
            .update(
                "sgd",
                Duration::from_millis(10),
                0.01,
                &grads.view(),
                &before.view(),
                &after.view(),
            )
            .expect("update should succeed for a registered optimizer");

        let report = collector.summary_report();
        assert!(report.contains("Optimizer: sgd"));
        assert!(report.contains("Steps: 1"));
    }

    #[test]
    fn test_metrics_reporter_json() {
        let metrics = OptimizerMetrics::new("test");
        let json = MetricsReporter::to_json(&metrics);
        assert!(json.contains("\"name\": \"test\""));
        assert!(json.contains("\"step_count\": 0"));
    }

    #[test]
    fn test_metrics_reporter_csv() {
        let metrics = OptimizerMetrics::new("test");
        let header = MetricsReporter::to_csv_header();
        let row = MetricsReporter::to_csv(&metrics);

        assert!(header.contains("name"));
        assert!(header.contains("step_count"));
        assert!(row.starts_with("test,0,"));
    }

    #[test]
    fn test_convergence_metrics() {
        let mut convergence = ConvergenceMetrics::default();

        let mut param_stats = ParameterStatistics {
            update_magnitude: 1.0,
            ..Default::default()
        };
        // First update: 0.1 * 1.0 + 0.9 * 0.0 = 0.1
        convergence.update(&param_stats);
        assert_eq!(convergence.update_moving_avg, 0.1);

        // Second update: 0.1 * 0.5 + 0.9 * 0.1 = 0.14
        param_stats.update_magnitude = 0.5;
        convergence.update(&param_stats);
        assert!((convergence.update_moving_avg - 0.14).abs() < 1e-6);

        // A much smaller update than the moving average flags convergence.
        param_stats.update_magnitude = 0.05;
        convergence.update(&param_stats);
        assert!(convergence.is_converging);
        assert!(convergence.update_moving_avg > 0.0);
    }
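
    // An added end-to-end sketch: exercises the public flow in this module
    // (register -> update -> get_metrics -> report) using only items defined
    // above; all numeric values are arbitrary.
    #[test]
    fn test_collector_reporter_round_trip() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1, 0.2]);
        let before = Array1::from_vec(vec![1.0, 2.0]);
        let after = Array1::from_vec(vec![0.99, 1.98]);

        collector
            .update(
                "sgd",
                Duration::from_millis(2),
                0.01,
                &grads.view(),
                &before.view(),
                &after.view(),
            )
            .expect("update should succeed for a registered optimizer");

        let metrics = collector
            .get_metrics("sgd")
            .expect("sgd should be registered");
        let json = MetricsReporter::to_json(metrics);
        let csv = MetricsReporter::to_csv(metrics);

        assert!(json.contains("\"step_count\": 1"));
        assert!(csv.starts_with("sgd,1,"));
    }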
}