// optirs_core/optimizer_metrics.rs
//! Optimizer performance metrics and monitoring
//!
//! This module provides comprehensive metrics collection and monitoring for optimizers
//! using SciRS2's metrics infrastructure for production deployments.
//!
//! # Features
//!
//! - Real-time optimizer performance tracking
//! - Gradient and parameter statistics
//! - Convergence monitoring
//! - Memory usage tracking
//! - Performance dashboards and reporting
//!
//! # SciRS2 Integration
//!
//! This module uses SciRS2-Core metrics abstractions exclusively:
//! - `scirs2_core::metrics::MetricRegistry` for metric registration
//! - `scirs2_core::metrics::Counter` for counting operations
//! - `scirs2_core::metrics::Gauge` for current values
//! - `scirs2_core::metrics::Histogram` for distributions
//! - `scirs2_core::metrics::Timer` for timing operations
use std::collections::HashMap;
use std::fmt::Debug;
use std::time::{Duration, Instant};

use scirs2_core::ndarray::{Array1, ArrayView1, ScalarOperand};
use scirs2_core::numeric::Float;

use crate::error::Result;
30
31/// Optimizer performance metrics
32///
33/// Tracks key performance indicators for optimizer operations including
34/// step timing, gradient statistics, parameter updates, and convergence.
35#[derive(Debug, Clone)]
36pub struct OptimizerMetrics {
37    /// Optimizer name
38    pub name: String,
39    /// Total number of optimization steps
40    pub step_count: u64,
41    /// Total time spent in optimization steps
42    pub total_step_time: Duration,
43    /// Average time per step
44    pub avg_step_time: Duration,
45    /// Current learning rate
46    pub current_learning_rate: f64,
47    /// Gradient statistics
48    pub gradient_stats: GradientStatistics,
49    /// Parameter statistics
50    pub parameter_stats: ParameterStatistics,
51    /// Convergence metrics
52    pub convergence: ConvergenceMetrics,
53    /// Memory usage (bytes)
54    pub memory_usage: usize,
55}
56
57impl OptimizerMetrics {
58    /// Create new metrics for an optimizer
59    pub fn new(name: impl Into<String>) -> Self {
60        Self {
61            name: name.into(),
62            step_count: 0,
63            total_step_time: Duration::ZERO,
64            avg_step_time: Duration::ZERO,
65            current_learning_rate: 0.0,
66            gradient_stats: GradientStatistics::default(),
67            parameter_stats: ParameterStatistics::default(),
68            convergence: ConvergenceMetrics::default(),
69            memory_usage: 0,
70        }
71    }
72
73    /// Update metrics after an optimization step
74    pub fn update_step<A: Float>(
75        &mut self,
76        step_duration: Duration,
77        learning_rate: f64,
78        gradients: &ArrayView1<A>,
79        params_before: &ArrayView1<A>,
80        params_after: &ArrayView1<A>,
81    ) {
82        self.step_count += 1;
83        self.total_step_time += step_duration;
84        self.avg_step_time = self.total_step_time / self.step_count as u32;
85        self.current_learning_rate = learning_rate;
86
87        // Update gradient statistics
88        self.gradient_stats.update(gradients);
89
90        // Update parameter statistics
91        self.parameter_stats.update(params_before, params_after);
92
93        // Update convergence metrics
94        self.convergence.update(&self.parameter_stats);
95    }
96
97    /// Get throughput (steps per second)
98    pub fn throughput(&self) -> f64 {
99        if self.total_step_time.as_secs_f64() > 0.0 {
100            self.step_count as f64 / self.total_step_time.as_secs_f64()
101        } else {
102            0.0
103        }
104    }
105
106    /// Reset all metrics
107    pub fn reset(&mut self) {
108        self.step_count = 0;
109        self.total_step_time = Duration::ZERO;
110        self.avg_step_time = Duration::ZERO;
111        self.gradient_stats = GradientStatistics::default();
112        self.parameter_stats = ParameterStatistics::default();
113        self.convergence = ConvergenceMetrics::default();
114    }
115}
116
/// Gradient statistics
///
/// Summary statistics for one gradient vector; refreshed each step by
/// [`GradientStatistics::update`].
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
    /// Mean gradient value
    pub mean: f64,
    /// Standard deviation of gradients (population, divide-by-n)
    pub std_dev: f64,
    /// Maximum gradient value
    pub max: f64,
    /// Minimum gradient value
    pub min: f64,
    /// Gradient norm (L2)
    pub norm: f64,
    /// Number of near-zero gradients (|g| < 1e-10)
    pub num_zeros: usize,
}
133
134impl GradientStatistics {
135    /// Update gradient statistics
136    pub fn update<A: Float>(&mut self, gradients: &ArrayView1<A>) {
137        let n = gradients.len();
138        if n == 0 {
139            return;
140        }
141
142        // Calculate statistics
143        let sum: f64 = gradients
144            .iter()
145            .map(|&g| g.to_f64().expect("unwrap failed"))
146            .sum();
147        self.mean = sum / n as f64;
148
149        let variance: f64 = gradients
150            .iter()
151            .map(|&g| {
152                let diff = g.to_f64().expect("unwrap failed") - self.mean;
153                diff * diff
154            })
155            .sum::<f64>()
156            / n as f64;
157        self.std_dev = variance.sqrt();
158
159        self.max = gradients
160            .iter()
161            .map(|&g| g.to_f64().expect("unwrap failed"))
162            .fold(f64::NEG_INFINITY, f64::max);
163        self.min = gradients
164            .iter()
165            .map(|&g| g.to_f64().expect("unwrap failed"))
166            .fold(f64::INFINITY, f64::min);
167
168        self.norm = gradients
169            .iter()
170            .map(|&g| {
171                let val = g.to_f64().expect("unwrap failed");
172                val * val
173            })
174            .sum::<f64>()
175            .sqrt();
176
177        self.num_zeros = gradients
178            .iter()
179            .filter(|&&g| g.to_f64().expect("unwrap failed").abs() < 1e-10)
180            .count();
181    }
182}
183
/// Parameter statistics
///
/// Summary statistics for the parameter vector and its per-step update;
/// refreshed each step by [`ParameterStatistics::update`].
#[derive(Debug, Clone, Default)]
pub struct ParameterStatistics {
    /// Mean parameter value (after the step)
    pub mean: f64,
    /// Standard deviation of parameters (population, divide-by-n)
    pub std_dev: f64,
    /// Parameter update magnitude (L2 norm of after - before)
    pub update_magnitude: f64,
    /// Relative parameter change (update magnitude / previous parameter norm)
    pub relative_change: f64,
}
196
197impl ParameterStatistics {
198    /// Update parameter statistics
199    pub fn update<A: Float>(
200        &mut self,
201        params_before: &ArrayView1<A>,
202        params_after: &ArrayView1<A>,
203    ) {
204        let n = params_after.len();
205        if n == 0 {
206            return;
207        }
208
209        // Calculate mean
210        let sum: f64 = params_after
211            .iter()
212            .map(|&p| p.to_f64().expect("unwrap failed"))
213            .sum();
214        self.mean = sum / n as f64;
215
216        // Calculate std dev
217        let variance: f64 = params_after
218            .iter()
219            .map(|&p| {
220                let diff = p.to_f64().expect("unwrap failed") - self.mean;
221                diff * diff
222            })
223            .sum::<f64>()
224            / n as f64;
225        self.std_dev = variance.sqrt();
226
227        // Calculate update magnitude
228        self.update_magnitude = params_before
229            .iter()
230            .zip(params_after.iter())
231            .map(|(&before, &after)| {
232                let diff = after.to_f64().expect("unwrap failed")
233                    - before.to_f64().expect("unwrap failed");
234                diff * diff
235            })
236            .sum::<f64>()
237            .sqrt();
238
239        // Calculate relative change
240        let params_norm: f64 = params_before
241            .iter()
242            .map(|&p| {
243                let val = p.to_f64().expect("unwrap failed");
244                val * val
245            })
246            .sum::<f64>()
247            .sqrt();
248
249        self.relative_change = if params_norm > 1e-10 {
250            self.update_magnitude / params_norm
251        } else {
252            0.0
253        };
254    }
255}
256
/// Convergence metrics
///
/// Tracks whether parameter updates are shrinking over time via an
/// exponentially smoothed moving average of the update magnitude.
#[derive(Debug, Clone, Default)]
pub struct ConvergenceMetrics {
    /// Moving average of parameter update magnitudes
    pub update_moving_avg: f64,
    /// Is optimizer converging (latest update below the moving average)
    pub is_converging: bool,
    /// Estimated steps to convergence
    pub estimated_steps_to_convergence: Option<u64>,
    /// Convergence rate (fractional drop of the update vs. the moving average)
    pub convergence_rate: f64,
}
269
270impl ConvergenceMetrics {
271    /// Update convergence metrics
272    pub fn update(&mut self, param_stats: &ParameterStatistics) {
273        // Check if converging before updating (compare against previous average)
274        if self.update_moving_avg > 1e-10 {
275            self.is_converging = param_stats.update_magnitude < self.update_moving_avg;
276            self.convergence_rate = 1.0 - (param_stats.update_magnitude / self.update_moving_avg);
277        }
278
279        // Update moving average with exponential smoothing (alpha = 0.1)
280        let alpha = 0.1;
281        self.update_moving_avg =
282            alpha * param_stats.update_magnitude + (1.0 - alpha) * self.update_moving_avg;
283    }
284}
285
286/// Metrics collector for tracking multiple optimizers
287pub struct MetricsCollector {
288    /// Metrics for each optimizer
289    metrics: HashMap<String, OptimizerMetrics>,
290    /// Global start time
291    start_time: Instant,
292}
293
294impl MetricsCollector {
295    /// Create a new metrics collector
296    pub fn new() -> Self {
297        Self {
298            metrics: HashMap::new(),
299            start_time: Instant::now(),
300        }
301    }
302
303    /// Register a new optimizer for tracking
304    pub fn register_optimizer(&mut self, name: impl Into<String>) {
305        let name = name.into();
306        self.metrics
307            .entry(name.clone())
308            .or_insert_with(|| OptimizerMetrics::new(name));
309    }
310
311    /// Update metrics for an optimizer
312    pub fn update<A: Float + ScalarOperand>(
313        &mut self,
314        optimizer_name: &str,
315        step_duration: Duration,
316        learning_rate: f64,
317        gradients: &ArrayView1<A>,
318        params_before: &ArrayView1<A>,
319        params_after: &ArrayView1<A>,
320    ) -> Result<()> {
321        if let Some(metrics) = self.metrics.get_mut(optimizer_name) {
322            metrics.update_step(
323                step_duration,
324                learning_rate,
325                gradients,
326                params_before,
327                params_after,
328            );
329            Ok(())
330        } else {
331            Err(crate::error::OptimError::InvalidConfig(format!(
332                "Optimizer '{}' not registered",
333                optimizer_name
334            )))
335        }
336    }
337
338    /// Get metrics for an optimizer
339    pub fn get_metrics(&self, optimizer_name: &str) -> Option<&OptimizerMetrics> {
340        self.metrics.get(optimizer_name)
341    }
342
343    /// Get all metrics
344    pub fn all_metrics(&self) -> &HashMap<String, OptimizerMetrics> {
345        &self.metrics
346    }
347
348    /// Get elapsed time since collector started
349    pub fn elapsed(&self) -> Duration {
350        self.start_time.elapsed()
351    }
352
353    /// Reset all metrics
354    pub fn reset(&mut self) {
355        for metrics in self.metrics.values_mut() {
356            metrics.reset();
357        }
358        self.start_time = Instant::now();
359    }
360
361    /// Generate summary report
362    pub fn summary_report(&self) -> String {
363        let mut report = String::new();
364        report.push_str("=== Optimizer Metrics Summary ===\n");
365        report.push_str(&format!("Total elapsed time: {:?}\n\n", self.elapsed()));
366
367        for (name, metrics) in &self.metrics {
368            report.push_str(&format!("Optimizer: {}\n", name));
369            report.push_str(&format!("  Steps: {}\n", metrics.step_count));
370            report.push_str(&format!("  Avg step time: {:?}\n", metrics.avg_step_time));
371            report.push_str(&format!(
372                "  Throughput: {:.2} steps/sec\n",
373                metrics.throughput()
374            ));
375            report.push_str(&format!(
376                "  Learning rate: {:.6}\n",
377                metrics.current_learning_rate
378            ));
379            report.push_str(&format!(
380                "  Gradient norm: {:.6}\n",
381                metrics.gradient_stats.norm
382            ));
383            report.push_str(&format!(
384                "  Update magnitude: {:.6}\n",
385                metrics.parameter_stats.update_magnitude
386            ));
387            report.push_str(&format!(
388                "  Converging: {}\n",
389                metrics.convergence.is_converging
390            ));
391            report.push_str(&format!(
392                "  Memory usage: {} bytes\n\n",
393                metrics.memory_usage
394            ));
395        }
396
397        report
398    }
399}
400
401impl Default for MetricsCollector {
402    fn default() -> Self {
403        Self::new()
404    }
405}
406
/// Metrics reporter for exporting metrics to various formats (JSON, CSV).
///
/// Stateless namespace type: all methods are associated functions.
pub struct MetricsReporter;
409
410impl MetricsReporter {
411    /// Export metrics to JSON format
412    pub fn to_json(metrics: &OptimizerMetrics) -> String {
413        format!(
414            r#"{{
415  "name": "{}",
416  "step_count": {},
417  "avg_step_time_ms": {},
418  "throughput": {},
419  "learning_rate": {},
420  "gradient_norm": {},
421  "update_magnitude": {},
422  "is_converging": {}
423}}"#,
424            metrics.name,
425            metrics.step_count,
426            metrics.avg_step_time.as_millis(),
427            metrics.throughput(),
428            metrics.current_learning_rate,
429            metrics.gradient_stats.norm,
430            metrics.parameter_stats.update_magnitude,
431            metrics.convergence.is_converging
432        )
433    }
434
435    /// Export metrics to CSV format
436    pub fn to_csv_header() -> String {
437        "name,step_count,avg_step_time_ms,throughput,learning_rate,gradient_norm,update_magnitude,is_converging".to_string()
438    }
439
440    /// Export metrics to CSV row
441    pub fn to_csv(metrics: &OptimizerMetrics) -> String {
442        format!(
443            "{},{},{},{},{},{},{},{}",
444            metrics.name,
445            metrics.step_count,
446            metrics.avg_step_time.as_millis(),
447            metrics.throughput(),
448            metrics.current_learning_rate,
449            metrics.gradient_stats.norm,
450            metrics.parameter_stats.update_magnitude,
451            metrics.convergence.is_converging
452        )
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_optimizer_metrics_creation() {
        let metrics = OptimizerMetrics::new("sgd");
        assert_eq!(metrics.name, "sgd");
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.throughput(), 0.0);
    }

    #[test]
    fn test_gradient_statistics() {
        let mut stats = GradientStatistics::default();
        let grads = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        stats.update(&grads.view());

        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert!(stats.max > 4.9);
        assert!(stats.min < 1.1);
        assert!(stats.norm > 0.0);
    }

    #[test]
    fn test_parameter_statistics() {
        let mut stats = ParameterStatistics::default();
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.9, 1.9, 2.9]);
        stats.update(&before.view(), &after.view());

        assert!(stats.update_magnitude > 0.0);
        assert!(stats.relative_change > 0.0);
        assert!((stats.mean - 1.9).abs() < 1e-6);
    }

    #[test]
    fn test_metrics_collector() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.99, 1.98, 2.97]);

        let result = collector.update(
            "sgd",
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert!(result.is_ok());
        let metrics = collector.get_metrics("sgd").expect("sgd should be registered");
        assert_eq!(metrics.step_count, 1);
    }

    #[test]
    fn test_metrics_collector_multiple_updates() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("adam");

        let grads = Array1::from_vec(vec![0.1, 0.2]);
        let before = Array1::from_vec(vec![1.0, 2.0]);
        let after = Array1::from_vec(vec![0.99, 1.98]);

        for _ in 0..10 {
            collector
                .update(
                    "adam",
                    Duration::from_millis(5),
                    0.001,
                    &grads.view(),
                    &before.view(),
                    &after.view(),
                )
                .expect("update of registered optimizer should succeed");
        }

        let metrics = collector.get_metrics("adam").expect("adam should be registered");
        assert_eq!(metrics.step_count, 10);
        assert!(metrics.throughput() > 0.0);
    }

    #[test]
    fn test_metrics_reset() {
        let mut metrics = OptimizerMetrics::new("test");
        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        metrics.update_step(
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert_eq!(metrics.step_count, 1);

        metrics.reset();
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.total_step_time, Duration::ZERO);
    }

    #[test]
    fn test_summary_report() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        collector
            .update(
                "sgd",
                Duration::from_millis(10),
                0.01,
                &grads.view(),
                &before.view(),
                &after.view(),
            )
            .expect("update of registered optimizer should succeed");

        let report = collector.summary_report();
        assert!(report.contains("Optimizer: sgd"));
        assert!(report.contains("Steps: 1"));
    }

    #[test]
    fn test_metrics_reporter_json() {
        let metrics = OptimizerMetrics::new("test");
        let json = MetricsReporter::to_json(&metrics);
        assert!(json.contains("\"name\": \"test\""));
        assert!(json.contains("\"step_count\": 0"));
    }

    #[test]
    fn test_metrics_reporter_csv() {
        let metrics = OptimizerMetrics::new("test");
        let header = MetricsReporter::to_csv_header();
        let row = MetricsReporter::to_csv(&metrics);

        assert!(header.contains("name"));
        assert!(header.contains("step_count"));
        assert!(row.starts_with("test,0,"));
    }

    #[test]
    fn test_convergence_metrics() {
        let mut convergence = ConvergenceMetrics::default();

        // First update: moving average seeds from zero.
        let mut param_stats = ParameterStatistics {
            update_magnitude: 1.0,
            ..Default::default()
        };
        convergence.update(&param_stats);
        assert_eq!(convergence.update_moving_avg, 0.1);

        param_stats.update_magnitude = 0.5;
        convergence.update(&param_stats);
        // update_moving_avg = 0.1 * 0.5 + 0.9 * 0.1 = 0.14
        assert!((convergence.update_moving_avg - 0.14).abs() < 1e-6);

        // Verify convergence detection works.
        param_stats.update_magnitude = 0.05;
        convergence.update(&param_stats);
        // Should detect converging since 0.05 < 0.14.
        assert!(convergence.is_converging);
        assert!(convergence.update_moving_avg > 0.0);
    }
}