use scirs2_core::ndarray::{Array1, ArrayView1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use std::time::{Duration, Instant};
use crate::error::Result;
#[derive(Debug, Clone)]
pub struct OptimizerMetrics {
pub name: String,
pub step_count: u64,
pub total_step_time: Duration,
pub avg_step_time: Duration,
pub current_learning_rate: f64,
pub gradient_stats: GradientStatistics,
pub parameter_stats: ParameterStatistics,
pub convergence: ConvergenceMetrics,
pub memory_usage: usize,
}
impl OptimizerMetrics {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
step_count: 0,
total_step_time: Duration::ZERO,
avg_step_time: Duration::ZERO,
current_learning_rate: 0.0,
gradient_stats: GradientStatistics::default(),
parameter_stats: ParameterStatistics::default(),
convergence: ConvergenceMetrics::default(),
memory_usage: 0,
}
}
pub fn update_step<A: Float>(
&mut self,
step_duration: Duration,
learning_rate: f64,
gradients: &ArrayView1<A>,
params_before: &ArrayView1<A>,
params_after: &ArrayView1<A>,
) {
self.step_count += 1;
self.total_step_time += step_duration;
self.avg_step_time = self.total_step_time / self.step_count as u32;
self.current_learning_rate = learning_rate;
self.gradient_stats.update(gradients);
self.parameter_stats.update(params_before, params_after);
self.convergence.update(&self.parameter_stats);
}
pub fn throughput(&self) -> f64 {
if self.total_step_time.as_secs_f64() > 0.0 {
self.step_count as f64 / self.total_step_time.as_secs_f64()
} else {
0.0
}
}
pub fn reset(&mut self) {
self.step_count = 0;
self.total_step_time = Duration::ZERO;
self.avg_step_time = Duration::ZERO;
self.gradient_stats = GradientStatistics::default();
self.parameter_stats = ParameterStatistics::default();
self.convergence = ConvergenceMetrics::default();
}
}
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
pub mean: f64,
pub std_dev: f64,
pub max: f64,
pub min: f64,
pub norm: f64,
pub num_zeros: usize,
}
impl GradientStatistics {
pub fn update<A: Float>(&mut self, gradients: &ArrayView1<A>) {
let n = gradients.len();
if n == 0 {
return;
}
let sum: f64 = gradients
.iter()
.map(|&g| g.to_f64().expect("unwrap failed"))
.sum();
self.mean = sum / n as f64;
let variance: f64 = gradients
.iter()
.map(|&g| {
let diff = g.to_f64().expect("unwrap failed") - self.mean;
diff * diff
})
.sum::<f64>()
/ n as f64;
self.std_dev = variance.sqrt();
self.max = gradients
.iter()
.map(|&g| g.to_f64().expect("unwrap failed"))
.fold(f64::NEG_INFINITY, f64::max);
self.min = gradients
.iter()
.map(|&g| g.to_f64().expect("unwrap failed"))
.fold(f64::INFINITY, f64::min);
self.norm = gradients
.iter()
.map(|&g| {
let val = g.to_f64().expect("unwrap failed");
val * val
})
.sum::<f64>()
.sqrt();
self.num_zeros = gradients
.iter()
.filter(|&&g| g.to_f64().expect("unwrap failed").abs() < 1e-10)
.count();
}
}
#[derive(Debug, Clone, Default)]
pub struct ParameterStatistics {
pub mean: f64,
pub std_dev: f64,
pub update_magnitude: f64,
pub relative_change: f64,
}
impl ParameterStatistics {
pub fn update<A: Float>(
&mut self,
params_before: &ArrayView1<A>,
params_after: &ArrayView1<A>,
) {
let n = params_after.len();
if n == 0 {
return;
}
let sum: f64 = params_after
.iter()
.map(|&p| p.to_f64().expect("unwrap failed"))
.sum();
self.mean = sum / n as f64;
let variance: f64 = params_after
.iter()
.map(|&p| {
let diff = p.to_f64().expect("unwrap failed") - self.mean;
diff * diff
})
.sum::<f64>()
/ n as f64;
self.std_dev = variance.sqrt();
self.update_magnitude = params_before
.iter()
.zip(params_after.iter())
.map(|(&before, &after)| {
let diff = after.to_f64().expect("unwrap failed")
- before.to_f64().expect("unwrap failed");
diff * diff
})
.sum::<f64>()
.sqrt();
let params_norm: f64 = params_before
.iter()
.map(|&p| {
let val = p.to_f64().expect("unwrap failed");
val * val
})
.sum::<f64>()
.sqrt();
self.relative_change = if params_norm > 1e-10 {
self.update_magnitude / params_norm
} else {
0.0
};
}
}
#[derive(Debug, Clone, Default)]
pub struct ConvergenceMetrics {
pub update_moving_avg: f64,
pub is_converging: bool,
pub estimated_steps_to_convergence: Option<u64>,
pub convergence_rate: f64,
}
impl ConvergenceMetrics {
pub fn update(&mut self, param_stats: &ParameterStatistics) {
if self.update_moving_avg > 1e-10 {
self.is_converging = param_stats.update_magnitude < self.update_moving_avg;
self.convergence_rate = 1.0 - (param_stats.update_magnitude / self.update_moving_avg);
}
let alpha = 0.1;
self.update_moving_avg =
alpha * param_stats.update_magnitude + (1.0 - alpha) * self.update_moving_avg;
}
}
pub struct MetricsCollector {
metrics: HashMap<String, OptimizerMetrics>,
start_time: Instant,
}
impl MetricsCollector {
pub fn new() -> Self {
Self {
metrics: HashMap::new(),
start_time: Instant::now(),
}
}
pub fn register_optimizer(&mut self, name: impl Into<String>) {
let name = name.into();
self.metrics
.entry(name.clone())
.or_insert_with(|| OptimizerMetrics::new(name));
}
pub fn update<A: Float + ScalarOperand>(
&mut self,
optimizer_name: &str,
step_duration: Duration,
learning_rate: f64,
gradients: &ArrayView1<A>,
params_before: &ArrayView1<A>,
params_after: &ArrayView1<A>,
) -> Result<()> {
if let Some(metrics) = self.metrics.get_mut(optimizer_name) {
metrics.update_step(
step_duration,
learning_rate,
gradients,
params_before,
params_after,
);
Ok(())
} else {
Err(crate::error::OptimError::InvalidConfig(format!(
"Optimizer '{}' not registered",
optimizer_name
)))
}
}
pub fn get_metrics(&self, optimizer_name: &str) -> Option<&OptimizerMetrics> {
self.metrics.get(optimizer_name)
}
pub fn all_metrics(&self) -> &HashMap<String, OptimizerMetrics> {
&self.metrics
}
pub fn elapsed(&self) -> Duration {
self.start_time.elapsed()
}
pub fn reset(&mut self) {
for metrics in self.metrics.values_mut() {
metrics.reset();
}
self.start_time = Instant::now();
}
pub fn summary_report(&self) -> String {
let mut report = String::new();
report.push_str("=== Optimizer Metrics Summary ===\n");
report.push_str(&format!("Total elapsed time: {:?}\n\n", self.elapsed()));
for (name, metrics) in &self.metrics {
report.push_str(&format!("Optimizer: {}\n", name));
report.push_str(&format!(" Steps: {}\n", metrics.step_count));
report.push_str(&format!(" Avg step time: {:?}\n", metrics.avg_step_time));
report.push_str(&format!(
" Throughput: {:.2} steps/sec\n",
metrics.throughput()
));
report.push_str(&format!(
" Learning rate: {:.6}\n",
metrics.current_learning_rate
));
report.push_str(&format!(
" Gradient norm: {:.6}\n",
metrics.gradient_stats.norm
));
report.push_str(&format!(
" Update magnitude: {:.6}\n",
metrics.parameter_stats.update_magnitude
));
report.push_str(&format!(
" Converging: {}\n",
metrics.convergence.is_converging
));
report.push_str(&format!(
" Memory usage: {} bytes\n\n",
metrics.memory_usage
));
}
report
}
}
impl Default for MetricsCollector {
fn default() -> Self {
Self::new()
}
}
pub struct MetricsReporter;
impl MetricsReporter {
pub fn to_json(metrics: &OptimizerMetrics) -> String {
format!(
r#"{{
"name": "{}",
"step_count": {},
"avg_step_time_ms": {},
"throughput": {},
"learning_rate": {},
"gradient_norm": {},
"update_magnitude": {},
"is_converging": {}
}}"#,
metrics.name,
metrics.step_count,
metrics.avg_step_time.as_millis(),
metrics.throughput(),
metrics.current_learning_rate,
metrics.gradient_stats.norm,
metrics.parameter_stats.update_magnitude,
metrics.convergence.is_converging
)
}
pub fn to_csv_header() -> String {
"name,step_count,avg_step_time_ms,throughput,learning_rate,gradient_norm,update_magnitude,is_converging".to_string()
}
pub fn to_csv(metrics: &OptimizerMetrics) -> String {
format!(
"{},{},{},{},{},{},{},{}",
metrics.name,
metrics.step_count,
metrics.avg_step_time.as_millis(),
metrics.throughput(),
metrics.current_learning_rate,
metrics.gradient_stats.norm,
metrics.parameter_stats.update_magnitude,
metrics.convergence.is_converging
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
#[test]
fn test_optimizer_metrics_creation() {
let metrics = OptimizerMetrics::new("sgd");
assert_eq!(metrics.name, "sgd");
assert_eq!(metrics.step_count, 0);
assert_eq!(metrics.throughput(), 0.0);
}
#[test]
fn test_gradient_statistics() {
let mut stats = GradientStatistics::default();
let grads = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
stats.update(&grads.view());
assert!((stats.mean - 3.0).abs() < 1e-6);
assert!(stats.max > 4.9);
assert!(stats.min < 1.1);
assert!(stats.norm > 0.0);
}
#[test]
fn test_parameter_statistics() {
let mut stats = ParameterStatistics::default();
let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let after = Array1::from_vec(vec![0.9, 1.9, 2.9]);
stats.update(&before.view(), &after.view());
assert!(stats.update_magnitude > 0.0);
assert!(stats.relative_change > 0.0);
assert!((stats.mean - 1.9).abs() < 1e-6);
}
#[test]
fn test_metrics_collector() {
let mut collector = MetricsCollector::new();
collector.register_optimizer("sgd");
let grads = Array1::from_vec(vec![0.1, 0.2, 0.3]);
let before = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let after = Array1::from_vec(vec![0.99, 1.98, 2.97]);
let result = collector.update(
"sgd",
Duration::from_millis(10),
0.01,
&grads.view(),
&before.view(),
&after.view(),
);
assert!(result.is_ok());
let metrics = collector.get_metrics("sgd").expect("unwrap failed");
assert_eq!(metrics.step_count, 1);
}
#[test]
fn test_metrics_collector_multiple_updates() {
let mut collector = MetricsCollector::new();
collector.register_optimizer("adam");
let grads = Array1::from_vec(vec![0.1, 0.2]);
let before = Array1::from_vec(vec![1.0, 2.0]);
let after = Array1::from_vec(vec![0.99, 1.98]);
for _ in 0..10 {
collector
.update(
"adam",
Duration::from_millis(5),
0.001,
&grads.view(),
&before.view(),
&after.view(),
)
.expect("unwrap failed");
}
let metrics = collector.get_metrics("adam").expect("unwrap failed");
assert_eq!(metrics.step_count, 10);
assert!(metrics.throughput() > 0.0);
}
#[test]
fn test_metrics_reset() {
let mut metrics = OptimizerMetrics::new("test");
let grads = Array1::from_vec(vec![0.1]);
let before = Array1::from_vec(vec![1.0]);
let after = Array1::from_vec(vec![0.99]);
metrics.update_step(
Duration::from_millis(10),
0.01,
&grads.view(),
&before.view(),
&after.view(),
);
assert_eq!(metrics.step_count, 1);
metrics.reset();
assert_eq!(metrics.step_count, 0);
assert_eq!(metrics.total_step_time, Duration::ZERO);
}
#[test]
fn test_summary_report() {
let mut collector = MetricsCollector::new();
collector.register_optimizer("sgd");
let grads = Array1::from_vec(vec![0.1]);
let before = Array1::from_vec(vec![1.0]);
let after = Array1::from_vec(vec![0.99]);
collector
.update(
"sgd",
Duration::from_millis(10),
0.01,
&grads.view(),
&before.view(),
&after.view(),
)
.expect("unwrap failed");
let report = collector.summary_report();
assert!(report.contains("Optimizer: sgd"));
assert!(report.contains("Steps: 1"));
}
#[test]
fn test_metrics_reporter_json() {
let metrics = OptimizerMetrics::new("test");
let json = MetricsReporter::to_json(&metrics);
assert!(json.contains("\"name\": \"test\""));
assert!(json.contains("\"step_count\": 0"));
}
#[test]
fn test_metrics_reporter_csv() {
let metrics = OptimizerMetrics::new("test");
let header = MetricsReporter::to_csv_header();
let row = MetricsReporter::to_csv(&metrics);
assert!(header.contains("name"));
assert!(header.contains("step_count"));
assert!(row.starts_with("test,0,"));
}
#[test]
fn test_convergence_metrics() {
let mut convergence = ConvergenceMetrics::default();
let mut param_stats = ParameterStatistics {
update_magnitude: 1.0,
..Default::default()
};
convergence.update(¶m_stats);
assert_eq!(convergence.update_moving_avg, 0.1);
param_stats.update_magnitude = 0.5;
convergence.update(¶m_stats);
assert!((convergence.update_moving_avg - 0.14).abs() < 1e-6);
param_stats.update_magnitude = 0.05;
convergence.update(¶m_stats);
assert!(convergence.is_converging);
assert!(convergence.update_moving_avg > 0.0);
}
}