use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Settings that control how gradients are accumulated before an update.
#[derive(Debug, Clone)]
pub struct GradientAccumulationConfig {
    /// Number of steps `GradientAccumulator::should_update` waits for:
    /// it returns `true` once `current_step >= accumulation_steps`.
    pub accumulation_steps: usize,
    /// Normalization toggle. It is not read by the code in this section.
    /// NOTE(review): presumably controls whether accumulated gradients are
    /// averaged over `accumulation_steps`; confirm at the usage site.
    pub normalize: bool,
}

impl Default for GradientAccumulationConfig {
    /// A single accumulation step with normalization enabled.
    fn default() -> Self {
        let accumulation_steps = 1;
        let normalize = true;
        Self {
            accumulation_steps,
            normalize,
        }
    }
}
/// Average, maximum, and minimum gradient-norm values.
#[derive(Debug, Clone)]
pub struct GradientStats<F>
where
    F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive,
{
    /// Average gradient norm.
    pub avg_grad_norm: F,
    /// Largest gradient norm.
    pub max_grad_norm: F,
    /// Smallest gradient norm.
    pub min_grad_norm: F,
}
/// Tracks accumulation progress against a [`GradientAccumulationConfig`]
/// and optionally holds [`GradientStats`].
#[derive(Debug)]
pub struct GradientAccumulator<F>
where
    F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive,
{
    /// Configuration; `accumulation_steps` is the threshold for updates.
    pub config: GradientAccumulationConfig,
    /// Step counter compared against `config.accumulation_steps`.
    pub current_step: usize,
    /// Gradient statistics, `None` until populated.
    pub stats: Option<GradientStats<F>>,
}
impl<F> GradientAccumulator<F>
where
    F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive,
{
    /// Builds an accumulator from `config`, with the step counter at zero
    /// and no statistics recorded.
    pub fn new(config: GradientAccumulationConfig) -> Self {
        let current_step = 0;
        let stats = None;
        Self {
            config,
            current_step,
            stats,
        }
    }

    /// Clears accumulation state: drops any statistics and rewinds the
    /// step counter to zero. The configuration is left untouched.
    pub fn reset(&mut self) {
        self.stats = None;
        self.current_step = 0;
    }

    /// Returns `true` once `current_step` has reached
    /// `config.accumulation_steps`.
    ///
    /// Note: with `accumulation_steps == 0` this is always `true`.
    pub fn should_update(&self) -> bool {
        self.config.accumulation_steps <= self.current_step
    }
}