use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;

use scirs2_core::ndarray::Array2;

/// Map from parameter name to its 2-D gradient array, as produced by the
/// training loop for each named layer.
type GradientMap = HashMap<String, Array2<f64>>;
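
/// Tracks per-batch gradient statistics during training and warns when the
/// global gradient norm looks vanishing (below a threshold) or exploding
/// (above a threshold). Statistics are printed every `log_frequency` batches.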
pub struct GradientMonitor {
log_frequency: usize,
vanishing_threshold: f64,
exploding_threshold: f64,
pub gradient_norms: Vec<f64>,
pub gradient_means: Vec<f64>,
pub gradient_stds: Vec<f64>,
pub vanishing_count: usize,
pub exploding_count: usize,
batch_counter: usize,
}
impl GradientMonitor {
    /// Creates a monitor that logs every `log_frequency` batches and flags
    /// norms below `vanishing_threshold` or above `exploding_threshold`.
    pub fn new(log_frequency: usize, vanishing_threshold: f64, exploding_threshold: f64) -> Self {
Self {
log_frequency,
vanishing_threshold,
exploding_threshold,
gradient_norms: Vec::new(),
gradient_means: Vec::new(),
gradient_stds: Vec::new(),
vanishing_count: 0,
exploding_count: 0,
batch_counter: 0,
}
}
    /// Computes `(norm, mean, std)` for the current batch. The gradients are
    /// not yet read from `TrainingState`, so fixed placeholder values are
    /// returned for now.
    fn compute_gradient_stats(&mut self, _state: &TrainingState) -> (f64, f64, f64) {
        (1.0, 0.0, 0.1)
    }
fn check_vanishing(&mut self, norm: f64) -> bool {
if norm < self.vanishing_threshold {
self.vanishing_count += 1;
return true;
}
false
}
fn check_exploding(&mut self, norm: f64) -> bool {
if norm > self.exploding_threshold {
self.exploding_count += 1;
return true;
}
false
}
fn print_stats(&self, norm: f64, mean: f64, std: f64) {
println!("Gradient Stats [Batch {}]:", self.batch_counter);
println!(" Norm: {:.6e}, Mean: {:.6e}, Std: {:.6e}", norm, mean, std);
        if self.vanishing_count > 0 {
            println!(
                " Warning: vanishing gradients detected in {} batches",
                self.vanishing_count
            );
        }
        if self.exploding_count > 0 {
            println!(
                " Warning: exploding gradients detected in {} batches",
                self.exploding_count
            );
        }
}
    /// Produces aggregate statistics over all batches observed so far.
    pub fn summary(&self) -> GradientSummary {
let avg_norm = if !self.gradient_norms.is_empty() {
self.gradient_norms.iter().sum::<f64>() / self.gradient_norms.len() as f64
} else {
0.0
};
GradientSummary {
total_batches: self.batch_counter,
average_norm: avg_norm,
vanishing_count: self.vanishing_count,
exploding_count: self.exploding_count,
}
}
}
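/// Aggregate gradient statistics reported by [`GradientMonitor::summary`].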
#[derive(Debug, Clone)]
pub struct GradientSummary {
pub total_batches: usize,
pub average_norm: f64,
pub vanishing_count: usize,
pub exploding_count: usize,
}
impl Callback for GradientMonitor {
fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
self.batch_counter += 1;
let (norm, mean, std) = self.compute_gradient_stats(state);
self.gradient_norms.push(norm);
self.gradient_means.push(mean);
self.gradient_stds.push(std);
let vanishing = self.check_vanishing(norm);
let exploding = self.check_exploding(norm);
        // Log on the regular schedule, or immediately when a warning fires.
        if self.batch_counter.is_multiple_of(self.log_frequency) || vanishing || exploding {
            self.print_stats(norm, mean, std);
        }
Ok(())
}
fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
let summary = self.summary();
println!("\n=== Gradient Monitoring Summary ===");
println!("Total batches: {}", summary.total_batches);
println!("Average gradient norm: {:.6e}", summary.average_norm);
println!("Vanishing gradient warnings: {}", summary.vanishing_count);
println!("Exploding gradient warnings: {}", summary.exploding_count);
println!("====================================\n");
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradientScalingStrategy {
    /// Divide the accumulated sum by `accumulation_steps`.
    Average,
    /// Apply the raw accumulated sum without rescaling.
    Sum,
    /// Divide by the number of steps actually accumulated, which may be
    /// fewer than `accumulation_steps` (e.g. at the end of an epoch).
    Dynamic,
}
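/// Accumulates gradients across several micro-batches so that a single
/// optimizer update can simulate a larger effective batch size.
///
/// A minimal driver sketch, assuming the surrounding training loop provides
/// the hypothetical `compute_gradients` and `apply_update` helpers:
///
/// ```ignore
/// let mut accum = GradientAccumulationCallback::new(4)?.with_grad_clipping(1.0);
/// for batch in batches {
///     let grads = compute_gradients(&batch);
///     accum.accumulate(&grads)?;
///     if accum.should_update() {
///         // Gradients come back scaled per the configured strategy.
///         apply_update(&accum.get_and_reset());
///     }
/// }
/// ```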
pub struct GradientAccumulationCallback {
accumulation_steps: usize,
current_step: usize,
    accumulated_grads: GradientMap,
initialized: bool,
scaling_strategy: GradientScalingStrategy,
max_grad_norm: f64,
overflow_detected: bool,
total_cycles: usize,
clip_grad_norm: Option<f64>,
}
impl GradientAccumulationCallback {
    /// Creates an accumulator that averages gradients over `accumulation_steps`.
    pub fn new(accumulation_steps: usize) -> TrainResult<Self> {
Self::with_strategy(accumulation_steps, GradientScalingStrategy::Average)
}
    /// Creates an accumulator with an explicit scaling strategy. Fails if
    /// `accumulation_steps` is zero.
    pub fn with_strategy(
accumulation_steps: usize,
scaling_strategy: GradientScalingStrategy,
) -> TrainResult<Self> {
if accumulation_steps == 0 {
return Err(TrainError::CallbackError(
"Accumulation steps must be greater than 0".to_string(),
));
}
Ok(Self {
accumulation_steps,
current_step: 0,
accumulated_grads: HashMap::new(),
initialized: false,
scaling_strategy,
max_grad_norm: 0.0,
overflow_detected: false,
total_cycles: 0,
clip_grad_norm: None,
})
}
    /// Enables global-norm gradient clipping with the given maximum norm.
    pub fn with_grad_clipping(mut self, max_norm: f64) -> Self {
self.clip_grad_norm = Some(max_norm);
self
}
    /// Accumulates one micro-batch of gradients, applying global-norm
    /// clipping first when configured via `with_grad_clipping`.
    pub fn accumulate(&mut self, gradients: &GradientMap) -> TrainResult<()> {
        // Reject the entire batch if any gradient contains NaN or Inf.
        for grad in gradients.values() {
            if grad.iter().any(|&x| x.is_nan() || x.is_infinite()) {
                self.overflow_detected = true;
                return Err(TrainError::CallbackError(
                    "Gradient overflow detected (NaN or Inf)".to_string(),
                ));
            }
        }
        let grad_norm = self.compute_total_norm(gradients);
        self.max_grad_norm = self.max_grad_norm.max(grad_norm);
        // When the global norm exceeds the clipping threshold, every gradient
        // is scaled by max_norm / grad_norm; otherwise it is used unchanged.
        let clip_scale = match self.clip_grad_norm {
            Some(max_norm) if grad_norm > max_norm => Some(max_norm / grad_norm),
            _ => None,
        };
        let clip = |grad: &Array2<f64>| match clip_scale {
            Some(scale) => grad * scale,
            None => grad.clone(),
        };
        if !self.initialized {
            // First micro-batch of the cycle: initialize the accumulator.
            for (name, grad) in gradients {
                self.accumulated_grads.insert(name.clone(), clip(grad));
            }
            self.initialized = true;
        } else {
            // Later micro-batches: add element-wise to the existing entries.
            for (name, grad) in gradients {
                if let Some(acc_grad) = self.accumulated_grads.get_mut(name) {
                    *acc_grad = &*acc_grad + &clip(grad);
                }
            }
        }
        self.current_step += 1;
        Ok(())
    }
    /// Computes the global L2 norm across all parameter gradients.
    fn compute_total_norm(&self, gradients: &GradientMap) -> f64 {
        gradients
            .values()
            .map(|grad| grad.iter().map(|&x| x * x).sum::<f64>())
            .sum::<f64>()
            .sqrt()
    }
    /// True once enough micro-batches have been accumulated for an update.
    pub fn should_update(&self) -> bool {
self.current_step >= self.accumulation_steps
}
    /// Returns the accumulated gradients scaled according to the configured
    /// strategy and resets all per-cycle state for the next cycle.
    pub fn get_and_reset(&mut self) -> GradientMap {
        let scale = match self.scaling_strategy {
            GradientScalingStrategy::Average => 1.0 / self.accumulation_steps as f64,
            GradientScalingStrategy::Sum => 1.0,
            // `max(1)` guards against division by zero when nothing has been
            // accumulated yet.
            GradientScalingStrategy::Dynamic => 1.0 / self.current_step.max(1) as f64,
        };
let mut scaled_grads = HashMap::new();
for (name, grad) in &self.accumulated_grads {
scaled_grads.insert(name.clone(), grad * scale);
}
self.total_cycles += 1;
self.current_step = 0;
self.initialized = false;
self.accumulated_grads.clear();
self.max_grad_norm = 0.0;
self.overflow_detected = false;
scaled_grads
}
    /// Returns a snapshot of the accumulator's current state.
    pub fn get_stats(&self) -> GradientAccumulationStats {
let memory_usage = self.estimate_memory_usage();
GradientAccumulationStats {
accumulation_steps: self.accumulation_steps,
current_step: self.current_step,
total_cycles: self.total_cycles,
max_grad_norm: self.max_grad_norm,
overflow_detected: self.overflow_detected,
num_parameters: self.accumulated_grads.len(),
memory_usage_mb: memory_usage,
}
}
    fn estimate_memory_usage(&self) -> f64 {
        // 8 bytes per f64 element, reported in MiB.
        let total_elements: usize = self.accumulated_grads.values().map(|g| g.len()).sum();
        (total_elements * 8) as f64 / (1024.0 * 1024.0)
    }
    /// Discards any partially accumulated cycle and clears per-cycle stats.
    pub fn reset(&mut self) {
self.current_step = 0;
self.initialized = false;
self.accumulated_grads.clear();
self.max_grad_norm = 0.0;
self.overflow_detected = false;
}
}
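/// Snapshot of accumulator state returned by
/// [`GradientAccumulationCallback::get_stats`].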
#[derive(Debug, Clone)]
pub struct GradientAccumulationStats {
pub accumulation_steps: usize,
pub current_step: usize,
pub total_cycles: usize,
pub max_grad_norm: f64,
pub overflow_detected: bool,
pub num_parameters: usize,
pub memory_usage_mb: f64,
}
impl Callback for GradientAccumulationCallback {
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        // Start each epoch with a clean accumulator; `reset` also clears the
        // per-cycle max-norm and overflow flags.
        self.reset();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
fn create_test_gradients() -> HashMap<String, Array2<f64>> {
let mut grads = HashMap::new();
grads.insert(
"layer1".to_string(),
Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap"),
);
grads.insert(
"layer2".to_string(),
Array2::from_shape_vec((2, 2), vec![0.5, 1.0, 1.5, 2.0]).expect("unwrap"),
);
grads
}
#[test]
fn test_gradient_accumulation_average_strategy() {
let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
let grads = create_test_gradients();
accum.accumulate(&grads).expect("unwrap");
assert_eq!(accum.current_step, 1);
assert!(!accum.should_update());
accum.accumulate(&grads).expect("unwrap");
assert_eq!(accum.current_step, 2);
assert!(accum.should_update());
let averaged = accum.get_and_reset();
let layer1 = averaged.get("layer1").expect("unwrap");
        assert_eq!(layer1[[0, 0]], 1.0);
        assert_eq!(layer1[[0, 1]], 2.0);
assert_eq!(accum.current_step, 0);
}
#[test]
fn test_gradient_accumulation_sum_strategy() {
let mut accum =
GradientAccumulationCallback::with_strategy(2, GradientScalingStrategy::Sum)
.expect("unwrap");
let grads = create_test_gradients();
accum.accumulate(&grads).expect("unwrap");
accum.accumulate(&grads).expect("unwrap");
let summed = accum.get_and_reset();
let layer1 = summed.get("layer1").expect("unwrap");
        assert_eq!(layer1[[0, 0]], 2.0);
        assert_eq!(layer1[[0, 1]], 4.0);
    }
#[test]
fn test_gradient_accumulation_dynamic_strategy() {
let mut accum =
GradientAccumulationCallback::with_strategy(4, GradientScalingStrategy::Dynamic)
.expect("unwrap");
let grads = create_test_gradients();
accum.accumulate(&grads).expect("unwrap");
accum.accumulate(&grads).expect("unwrap");
accum.accumulate(&grads).expect("unwrap");
let scaled = accum.get_and_reset();
let layer1 = scaled.get("layer1").expect("unwrap");
        // Three accumulated steps of 1.0 scaled by 1/3 yields 1.0.
        assert_eq!(layer1[[0, 0]], 1.0);
    }
#[test]
fn test_gradient_clipping_during_accumulation() {
let mut accum = GradientAccumulationCallback::new(2)
.expect("unwrap")
.with_grad_clipping(1.0);
let mut grads = HashMap::new();
grads.insert(
"layer1".to_string(),
Array2::from_shape_vec((2, 2), vec![10.0, 10.0, 10.0, 10.0]).expect("unwrap"),
);
accum.accumulate(&grads).expect("unwrap");
assert!(accum.max_grad_norm > 0.0);
let accumulated = &accum.accumulated_grads["layer1"];
let norm_sq: f64 = accumulated.iter().map(|&x| x * x).sum();
let norm = norm_sq.sqrt();
        // The clipped gradient's norm should be at most max_norm (1.0),
        // allowing a small tolerance for floating-point error.
        assert!(norm <= 1.1);
    }
#[test]
fn test_overflow_detection() {
let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
let mut grads = HashMap::new();
grads.insert(
"layer1".to_string(),
Array2::from_shape_vec((2, 2), vec![f64::NAN, 1.0, 2.0, 3.0]).expect("unwrap"),
);
let result = accum.accumulate(&grads);
assert!(result.is_err());
assert!(accum.overflow_detected);
}
#[test]
fn test_gradient_accumulation_stats() {
let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
let grads = create_test_gradients();
accum.accumulate(&grads).expect("unwrap");
accum.accumulate(&grads).expect("unwrap");
accum.get_and_reset();
let stats = accum.get_stats();
assert_eq!(stats.accumulation_steps, 2);
assert_eq!(stats.total_cycles, 1);
assert!(!stats.overflow_detected);
}
#[test]
fn test_memory_usage_estimation() {
let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
let grads = create_test_gradients();
accum.accumulate(&grads).expect("unwrap");
let stats = accum.get_stats();
assert!(stats.memory_usage_mb > 0.0);
        assert_eq!(stats.num_parameters, 2);
    }
#[test]
fn test_gradient_accumulation_reset() {
let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
let grads = create_test_gradients();
accum.accumulate(&grads).expect("unwrap");
assert_eq!(accum.current_step, 1);
accum.reset();
assert_eq!(accum.current_step, 0);
assert!(!accum.initialized);
assert_eq!(accum.accumulated_grads.len(), 0);
}
#[test]
fn test_gradient_accumulation_zero_steps_error() {
let result = GradientAccumulationCallback::new(0);
assert!(result.is_err());
}
#[test]
fn test_gradient_accumulation_multiple_cycles() {
let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
let grads = create_test_gradients();
accum.accumulate(&grads).expect("unwrap");
accum.accumulate(&grads).expect("unwrap");
accum.get_and_reset();
accum.accumulate(&grads).expect("unwrap");
accum.accumulate(&grads).expect("unwrap");
accum.get_and_reset();
let stats = accum.get_stats();
assert_eq!(stats.total_cycles, 2);
}
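    // A small addition covering the `current_step.max(1)` guard: draining the
    // accumulator before anything was accumulated should yield an empty map
    // rather than dividing by zero.
    #[test]
    fn test_get_and_reset_without_accumulation() {
        let mut accum =
            GradientAccumulationCallback::with_strategy(4, GradientScalingStrategy::Dynamic)
                .expect("unwrap");
        let empty = accum.get_and_reset();
        assert!(empty.is_empty());
        assert_eq!(accum.get_stats().total_cycles, 1);
    }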
#[test]
    fn test_non_square_gradient_shapes() {
let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
let mut grads1 = HashMap::new();
grads1.insert(
"layer1".to_string(),
Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("unwrap"),
);
let mut grads2 = HashMap::new();
grads2.insert(
"layer1".to_string(),
Array2::from_shape_vec((2, 3), vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0]).expect("unwrap"),
);
accum.accumulate(&grads1).expect("unwrap");
accum.accumulate(&grads2).expect("unwrap");
let averaged = accum.get_and_reset();
let layer1 = averaged.get("layer1").expect("unwrap");
assert_eq!(layer1.dim(), (2, 3));
        // Average of 1.0 and 0.5 is 0.75.
        assert_eq!(layer1[[0, 0]], 0.75);
    }
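    // A sanity check for the monitor's threshold bookkeeping and summary
    // math; it drives the private helpers directly since constructing a full
    // `TrainingState` is outside the scope of this module.
    #[test]
    fn test_gradient_monitor_thresholds_and_summary() {
        let mut monitor = GradientMonitor::new(10, 1e-7, 1e3);
        assert!(monitor.check_vanishing(1e-8));
        assert!(!monitor.check_vanishing(1.0));
        assert!(monitor.check_exploding(1e4));
        assert!(!monitor.check_exploding(1.0));
        monitor.gradient_norms.extend([1.0, 3.0]);
        let summary = monitor.summary();
        assert_eq!(summary.average_norm, 2.0);
        assert_eq!(summary.vanishing_count, 1);
        assert_eq!(summary.exploding_count, 1);
    }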
}