optirs_core/optimizer_metrics.rs

//! Optimizer performance metrics and monitoring
//!
//! This module provides comprehensive metrics collection and monitoring for optimizers
//! using SciRS2's metrics infrastructure for production deployments.
//!
//! # Features
//!
//! - Real-time optimizer performance tracking
//! - Gradient and parameter statistics
//! - Convergence monitoring
//! - Memory usage tracking
//! - Performance dashboards and reporting
//!
//! # SciRS2 Integration
//!
//! This module is designed to integrate with the SciRS2-Core metrics abstractions:
//! - `scirs2_core::metrics::MetricRegistry` for metric registration
//! - `scirs2_core::metrics::Counter` for counting operations
//! - `scirs2_core::metrics::Gauge` for current values
//! - `scirs2_core::metrics::Histogram` for distributions
//! - `scirs2_core::metrics::Timer` for timing operations
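//!
//! # Example
//!
//! A minimal usage sketch (the `optirs_core::optimizer_metrics` import path is
//! an assumption for illustration):
//!
//! ```ignore
//! use std::time::Instant;
//!
//! use optirs_core::optimizer_metrics::MetricsCollector;
//! use scirs2_core::ndarray::Array1;
//!
//! let mut collector = MetricsCollector::new();
//! collector.register_optimizer("sgd");
//!
//! let grads = Array1::from_vec(vec![0.1_f64, 0.2, 0.3]);
//! let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
//! let start = Instant::now();
//! // ... run one optimizer step here ...
//! let after = Array1::from_vec(vec![0.999, 1.998, 2.997]);
//!
//! collector
//!     .update("sgd", start.elapsed(), 0.01, &grads.view(), &before.view(), &after.view())
//!     .unwrap();
//! println!("{}", collector.summary_report());
//! ```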

use scirs2_core::ndarray::{ArrayView1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::time::{Duration, Instant};

use crate::error::Result;
/// Optimizer performance metrics
///
/// Tracks key performance indicators for optimizer operations including
/// step timing, gradient statistics, parameter updates, and convergence.
#[derive(Debug, Clone)]
pub struct OptimizerMetrics {
    /// Optimizer name
    pub name: String,
    /// Total number of optimization steps
    pub step_count: u64,
    /// Total time spent in optimization steps
    pub total_step_time: Duration,
    /// Average time per step
    pub avg_step_time: Duration,
    /// Current learning rate
    pub current_learning_rate: f64,
    /// Gradient statistics
    pub gradient_stats: GradientStatistics,
    /// Parameter statistics
    pub parameter_stats: ParameterStatistics,
    /// Convergence metrics
    pub convergence: ConvergenceMetrics,
    /// Memory usage (bytes)
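    /// (public for the caller to set; this module does not measure it itself)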
    pub memory_usage: usize,
}

impl OptimizerMetrics {
    /// Create new metrics for an optimizer
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            step_count: 0,
            total_step_time: Duration::ZERO,
            avg_step_time: Duration::ZERO,
            current_learning_rate: 0.0,
            gradient_stats: GradientStatistics::default(),
            parameter_stats: ParameterStatistics::default(),
            convergence: ConvergenceMetrics::default(),
            memory_usage: 0,
        }
    }

    /// Update metrics after an optimization step
    pub fn update_step<A: Float>(
        &mut self,
        step_duration: Duration,
        learning_rate: f64,
        gradients: &ArrayView1<A>,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) {
        self.step_count += 1;
        self.total_step_time += step_duration;
        // Floating-point division avoids truncating step_count to u32.
        self.avg_step_time = self.total_step_time.div_f64(self.step_count as f64);
        self.current_learning_rate = learning_rate;

        // Update gradient statistics
        self.gradient_stats.update(gradients);

        // Update parameter statistics
        self.parameter_stats.update(params_before, params_after);

        // Update convergence metrics
        self.convergence.update(&self.parameter_stats);
    }

    /// Get throughput (steps per second)
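    ///
    /// Computed as `step_count / total_step_time`; returns `0.0` before
    /// any step time has been recorded.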
    pub fn throughput(&self) -> f64 {
        if self.total_step_time.as_secs_f64() > 0.0 {
            self.step_count as f64 / self.total_step_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Reset the step, gradient, parameter, and convergence metrics
    /// (the learning rate and memory usage readings are left in place)
    pub fn reset(&mut self) {
        self.step_count = 0;
        self.total_step_time = Duration::ZERO;
        self.avg_step_time = Duration::ZERO;
        self.gradient_stats = GradientStatistics::default();
        self.parameter_stats = ParameterStatistics::default();
        self.convergence = ConvergenceMetrics::default();
    }
}

/// Gradient statistics
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
    /// Mean gradient value
    pub mean: f64,
    /// Standard deviation of gradients
    pub std_dev: f64,
    /// Maximum gradient value
    pub max: f64,
    /// Minimum gradient value
    pub min: f64,
    /// Gradient norm (L2)
    pub norm: f64,
    /// Number of near-zero gradients (|g| < 1e-10)
    pub num_zeros: usize,
}

impl GradientStatistics {
    /// Update gradient statistics
    pub fn update<A: Float>(&mut self, gradients: &ArrayView1<A>) {
        let n = gradients.len();
        if n == 0 {
            return;
        }

        // Calculate statistics
        let sum: f64 = gradients.iter().map(|&g| g.to_f64().unwrap()).sum();
        self.mean = sum / n as f64;

        let variance: f64 = gradients
            .iter()
            .map(|&g| {
                let diff = g.to_f64().unwrap() - self.mean;
                diff * diff
            })
            .sum::<f64>()
            / n as f64;
        self.std_dev = variance.sqrt();

        self.max = gradients
            .iter()
            .map(|&g| g.to_f64().unwrap())
            .fold(f64::NEG_INFINITY, f64::max);
        self.min = gradients
            .iter()
            .map(|&g| g.to_f64().unwrap())
            .fold(f64::INFINITY, f64::min);

        self.norm = gradients
            .iter()
            .map(|&g| {
                let val = g.to_f64().unwrap();
                val * val
            })
            .sum::<f64>()
            .sqrt();

        self.num_zeros = gradients
            .iter()
            .filter(|&&g| g.to_f64().unwrap().abs() < 1e-10)
            .count();
    }
}

/// Parameter statistics
#[derive(Debug, Clone, Default)]
pub struct ParameterStatistics {
    /// Mean parameter value
    pub mean: f64,
    /// Standard deviation of parameters
    pub std_dev: f64,
    /// Parameter update magnitude
    pub update_magnitude: f64,
    /// Relative parameter change
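    /// (`update_magnitude` divided by the L2 norm of the previous
    /// parameters; `0.0` when that norm is below `1e-10`)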
    pub relative_change: f64,
}

impl ParameterStatistics {
    /// Update parameter statistics
    pub fn update<A: Float>(
        &mut self,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) {
        let n = params_after.len();
        if n == 0 {
            return;
        }

        // Calculate mean
        let sum: f64 = params_after.iter().map(|&p| p.to_f64().unwrap()).sum();
        self.mean = sum / n as f64;

        // Calculate std dev
        let variance: f64 = params_after
            .iter()
            .map(|&p| {
                let diff = p.to_f64().unwrap() - self.mean;
                diff * diff
            })
            .sum::<f64>()
            / n as f64;
        self.std_dev = variance.sqrt();

        // Calculate update magnitude
        self.update_magnitude = params_before
            .iter()
            .zip(params_after.iter())
            .map(|(&before, &after)| {
                let diff = after.to_f64().unwrap() - before.to_f64().unwrap();
                diff * diff
            })
            .sum::<f64>()
            .sqrt();

        // Calculate relative change
        let params_norm: f64 = params_before
            .iter()
            .map(|&p| {
                let val = p.to_f64().unwrap();
                val * val
            })
            .sum::<f64>()
            .sqrt();

        self.relative_change = if params_norm > 1e-10 {
            self.update_magnitude / params_norm
        } else {
            0.0
        };
    }
}

/// Convergence metrics
#[derive(Debug, Clone, Default)]
pub struct ConvergenceMetrics {
    /// Moving average of parameter update magnitudes
    pub update_moving_avg: f64,
    /// Whether the optimizer appears to be converging (update magnitudes decreasing)
    pub is_converging: bool,
    /// Estimated steps to convergence (reserved; not currently computed)
    pub estimated_steps_to_convergence: Option<u64>,
    /// Convergence rate
    pub convergence_rate: f64,
}

impl ConvergenceMetrics {
    /// Update convergence metrics
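    ///
    /// The moving average uses exponential smoothing,
    /// `avg <- alpha * update_magnitude + (1 - alpha) * avg` with
    /// `alpha = 0.1`, so roughly the last ten updates dominate it.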
    pub fn update(&mut self, param_stats: &ParameterStatistics) {
        // Check if converging before updating (compare against previous average)
        if self.update_moving_avg > 1e-10 {
            self.is_converging = param_stats.update_magnitude < self.update_moving_avg;
            self.convergence_rate = 1.0 - (param_stats.update_magnitude / self.update_moving_avg);
        }

        // Update moving average with exponential smoothing (alpha = 0.1)
        let alpha = 0.1;
        self.update_moving_avg =
            alpha * param_stats.update_magnitude + (1.0 - alpha) * self.update_moving_avg;
    }
}

/// Metrics collector for tracking multiple optimizers
pub struct MetricsCollector {
    /// Metrics for each optimizer
    metrics: HashMap<String, OptimizerMetrics>,
    /// Global start time
    start_time: Instant,
}

impl MetricsCollector {
    /// Create a new metrics collector
    pub fn new() -> Self {
        Self {
            metrics: HashMap::new(),
            start_time: Instant::now(),
        }
    }

    /// Register a new optimizer for tracking
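    ///
    /// Re-registering an existing name is a no-op that keeps the
    /// accumulated metrics.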
    pub fn register_optimizer(&mut self, name: impl Into<String>) {
        let name = name.into();
        self.metrics
            .entry(name.clone())
            .or_insert_with(|| OptimizerMetrics::new(name));
    }

    /// Update metrics for an optimizer
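    ///
    /// Returns an `InvalidConfig` error if `optimizer_name` has not been
    /// registered via [`Self::register_optimizer`].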
    pub fn update<A: Float + ScalarOperand>(
        &mut self,
        optimizer_name: &str,
        step_duration: Duration,
        learning_rate: f64,
        gradients: &ArrayView1<A>,
        params_before: &ArrayView1<A>,
        params_after: &ArrayView1<A>,
    ) -> Result<()> {
        if let Some(metrics) = self.metrics.get_mut(optimizer_name) {
            metrics.update_step(
                step_duration,
                learning_rate,
                gradients,
                params_before,
                params_after,
            );
            Ok(())
        } else {
            Err(crate::error::OptimError::InvalidConfig(format!(
                "Optimizer '{}' not registered",
                optimizer_name
            )))
        }
    }

    /// Get metrics for an optimizer
    pub fn get_metrics(&self, optimizer_name: &str) -> Option<&OptimizerMetrics> {
        self.metrics.get(optimizer_name)
    }

    /// Get all metrics
    pub fn all_metrics(&self) -> &HashMap<String, OptimizerMetrics> {
        &self.metrics
    }

    /// Get elapsed time since collector started
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Reset all metrics
    pub fn reset(&mut self) {
        for metrics in self.metrics.values_mut() {
            metrics.reset();
        }
        self.start_time = Instant::now();
    }

    /// Generate summary report
    pub fn summary_report(&self) -> String {
        let mut report = String::new();
        report.push_str("=== Optimizer Metrics Summary ===\n");
        report.push_str(&format!("Total elapsed time: {:?}\n\n", self.elapsed()));

        for (name, metrics) in &self.metrics {
            report.push_str(&format!("Optimizer: {}\n", name));
            report.push_str(&format!("  Steps: {}\n", metrics.step_count));
            report.push_str(&format!("  Avg step time: {:?}\n", metrics.avg_step_time));
            report.push_str(&format!(
                "  Throughput: {:.2} steps/sec\n",
                metrics.throughput()
            ));
            report.push_str(&format!(
                "  Learning rate: {:.6}\n",
                metrics.current_learning_rate
            ));
            report.push_str(&format!(
                "  Gradient norm: {:.6}\n",
                metrics.gradient_stats.norm
            ));
            report.push_str(&format!(
                "  Update magnitude: {:.6}\n",
                metrics.parameter_stats.update_magnitude
            ));
            report.push_str(&format!(
                "  Converging: {}\n",
                metrics.convergence.is_converging
            ));
            report.push_str(&format!(
                "  Memory usage: {} bytes\n\n",
                metrics.memory_usage
            ));
        }

        report
    }
}

impl Default for MetricsCollector {
    fn default() -> Self {
        Self::new()
    }
}

/// Metrics reporter for exporting metrics to various formats
pub struct MetricsReporter;

impl MetricsReporter {
    /// Export metrics to JSON format
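    ///
    /// Note: field values are interpolated verbatim; optimizer names
    /// containing quotes or backslashes are not JSON-escaped.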
    pub fn to_json(metrics: &OptimizerMetrics) -> String {
        format!(
            r#"{{
  "name": "{}",
  "step_count": {},
  "avg_step_time_ms": {},
  "throughput": {},
  "learning_rate": {},
  "gradient_norm": {},
  "update_magnitude": {},
  "is_converging": {}
}}"#,
            metrics.name,
            metrics.step_count,
            metrics.avg_step_time.as_millis(),
            metrics.throughput(),
            metrics.current_learning_rate,
            metrics.gradient_stats.norm,
            metrics.parameter_stats.update_magnitude,
            metrics.convergence.is_converging
        )
    }

    /// Export the CSV header row matching the columns of [`Self::to_csv`]
    pub fn to_csv_header() -> String {
        "name,step_count,avg_step_time_ms,throughput,learning_rate,gradient_norm,update_magnitude,is_converging".to_string()
    }

    /// Export metrics as a single CSV row
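    ///
    /// Combine with [`Self::to_csv_header`] to build a complete CSV file,
    /// one row per optimizer.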
    pub fn to_csv(metrics: &OptimizerMetrics) -> String {
        format!(
            "{},{},{},{},{},{},{},{}",
            metrics.name,
            metrics.step_count,
            metrics.avg_step_time.as_millis(),
            metrics.throughput(),
            metrics.current_learning_rate,
            metrics.gradient_stats.norm,
            metrics.parameter_stats.update_magnitude,
            metrics.convergence.is_converging
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_optimizer_metrics_creation() {
        let metrics = OptimizerMetrics::new("sgd");
        assert_eq!(metrics.name, "sgd");
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.throughput(), 0.0);
    }

    #[test]
    fn test_gradient_statistics() {
        let mut stats = GradientStatistics::default();
        let grads = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        stats.update(&grads.view());

        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert!(stats.max > 4.9);
        assert!(stats.min < 1.1);
        assert!(stats.norm > 0.0);
    }
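
    // Near-zero gradients (|g| < 1e-10) are tallied in `num_zeros`; this
    // added test sketches that behavior.
    #[test]
    fn test_gradient_statistics_zero_count() {
        let mut stats = GradientStatistics::default();
        let grads = Array1::from_vec(vec![0.0, 1.0, 0.0, 2.0]);
        stats.update(&grads.view());
        assert_eq!(stats.num_zeros, 2);
    }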

    #[test]
    fn test_parameter_statistics() {
        let mut stats = ParameterStatistics::default();
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.9, 1.9, 2.9]);
        stats.update(&before.view(), &after.view());

        assert!(stats.update_magnitude > 0.0);
        assert!(stats.relative_change > 0.0);
        assert!((stats.mean - 1.9).abs() < 1e-6);
    }

    #[test]
    fn test_metrics_collector() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let after = Array1::from_vec(vec![0.99, 1.98, 2.97]);

        let result = collector.update(
            "sgd",
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert!(result.is_ok());
        let metrics = collector.get_metrics("sgd").unwrap();
        assert_eq!(metrics.step_count, 1);
    }

    #[test]
    fn test_metrics_collector_multiple_updates() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("adam");

        let grads = Array1::from_vec(vec![0.1, 0.2]);
        let before = Array1::from_vec(vec![1.0, 2.0]);
        let after = Array1::from_vec(vec![0.99, 1.98]);

        for _ in 0..10 {
            collector
                .update(
                    "adam",
                    Duration::from_millis(5),
                    0.001,
                    &grads.view(),
                    &before.view(),
                    &after.view(),
                )
                .unwrap();
        }

        let metrics = collector.get_metrics("adam").unwrap();
        assert_eq!(metrics.step_count, 10);
        assert!(metrics.throughput() > 0.0);
    }

    #[test]
    fn test_metrics_reset() {
        let mut metrics = OptimizerMetrics::new("test");
        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        metrics.update_step(
            Duration::from_millis(10),
            0.01,
            &grads.view(),
            &before.view(),
            &after.view(),
        );

        assert_eq!(metrics.step_count, 1);

        metrics.reset();
        assert_eq!(metrics.step_count, 0);
        assert_eq!(metrics.total_step_time, Duration::ZERO);
    }

    #[test]
    fn test_summary_report() {
        let mut collector = MetricsCollector::new();
        collector.register_optimizer("sgd");

        let grads = Array1::from_vec(vec![0.1]);
        let before = Array1::from_vec(vec![1.0]);
        let after = Array1::from_vec(vec![0.99]);

        collector
            .update(
                "sgd",
                Duration::from_millis(10),
                0.01,
                &grads.view(),
                &before.view(),
                &after.view(),
            )
            .unwrap();

        let report = collector.summary_report();
        assert!(report.contains("Optimizer: sgd"));
        assert!(report.contains("Steps: 1"));
    }

    #[test]
    fn test_metrics_reporter_json() {
        let metrics = OptimizerMetrics::new("test");
        let json = MetricsReporter::to_json(&metrics);
        assert!(json.contains("\"name\": \"test\""));
        assert!(json.contains("\"step_count\": 0"));
    }

    #[test]
    fn test_metrics_reporter_csv() {
        let metrics = OptimizerMetrics::new("test");
        let header = MetricsReporter::to_csv_header();
        let row = MetricsReporter::to_csv(&metrics);

        assert!(header.contains("name"));
        assert!(header.contains("step_count"));
        assert!(row.starts_with("test,0,"));
    }

    #[test]
    fn test_convergence_metrics() {
        let mut convergence = ConvergenceMetrics::default();

        // Update with some values
        let mut param_stats = ParameterStatistics {
            update_magnitude: 1.0,
            ..Default::default()
        };
        convergence.update(&param_stats);
        assert_eq!(convergence.update_moving_avg, 0.1);

        param_stats.update_magnitude = 0.5;
        convergence.update(&param_stats);
        // update_moving_avg = 0.1 * 0.5 + 0.9 * 0.1 = 0.14
        assert!((convergence.update_moving_avg - 0.14).abs() < 1e-6);

        // Verify convergence detection works
        param_stats.update_magnitude = 0.05;
        convergence.update(&param_stats);
        // Should detect converging since 0.05 < 0.14
        assert!(convergence.is_converging);
        assert!(convergence.update_moving_avg > 0.0);
    }
}