use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Boxed predicate used by the adaptive rules of `VariableAccumulator`.
///
/// It is called with the accumulator's current step count and returns `true`
/// when the rule it belongs to should be applied.
pub type AdaptiveStepCondition = Box<dyn Fn(usize) -> bool>;
/// How accumulated gradients are combined when they are read back.
///
/// The enum is field-less, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this allows it to be used as a map key or set element.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccumulationMode {
    /// Return the element-wise sum of every accumulated gradient.
    Sum,
    /// Return the element-wise sum divided by the number of accumulated calls.
    Average,
}
/// Buffers gradient arrays across several calls so that they can later be
/// retrieved as one summed or averaged set of arrays.
#[derive(Debug)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
/// Running element-wise sums, one buffer per gradient array.
accumulated_gradients: Vec<Array<A, D>>,
/// Number of successful `accumulate` calls since the last reset.
accumulation_count: usize,
/// Number of accumulations after which the accumulator reports itself ready.
target_accumulations: usize,
/// Whether the buffered sums are returned as-is or divided by the count.
mode: AccumulationMode,
/// Set once the buffers have been allocated from a first set of gradients.
initialized: bool,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
pub fn new(_targetaccumulations: usize, mode: AccumulationMode) -> Self {
Self {
accumulated_gradients: Vec::new(),
accumulation_count: 0,
target_accumulations: _targetaccumulations,
mode,
initialized: false,
}
}
pub fn initialize(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
if self.initialized {
return Err(OptimError::InvalidConfig(
"Accumulator already initialized".to_string(),
));
}
self.accumulated_gradients = gradients
.iter()
.map(|g| Array::zeros(g.raw_dim()))
.collect();
self.initialized = true;
Ok(())
}
pub fn accumulate(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
if !self.initialized {
self.initialize(gradients)?;
}
if gradients.len() != self.accumulated_gradients.len() {
return Err(OptimError::DimensionMismatch(format!(
"Expected {} gradient arrays, got {}",
self.accumulated_gradients.len(),
gradients.len()
)));
}
for (acc_grad, micro_grad) in self.accumulated_gradients.iter_mut().zip(gradients.iter()) {
if acc_grad.raw_dim() != micro_grad.raw_dim() {
return Err(OptimError::DimensionMismatch(
"Gradient dimensions don't match".to_string(),
));
}
Zip::from(acc_grad).and(micro_grad).for_each(|acc, µ| {
*acc = *acc + micro;
});
}
self.accumulation_count += 1;
Ok(())
}
pub fn is_ready(&self) -> bool {
self.accumulation_count >= self.target_accumulations
}
pub fn get_and_reset(&mut self) -> Result<Vec<Array<A, D>>> {
if !self.is_ready() {
return Err(OptimError::InvalidConfig(format!(
"Accumulation not ready: {}/{} steps completed",
self.accumulation_count, self.target_accumulations
)));
}
let mut result = self.accumulated_gradients.clone();
match self.mode {
AccumulationMode::Sum => {
}
AccumulationMode::Average => {
let scale = A::one() / A::from(self.accumulation_count).expect("unwrap failed");
for grad in &mut result {
grad.mapv_inplace(|x| x * scale);
}
}
}
self.reset();
Ok(result)
}
pub fn reset(&mut self) {
for grad in &mut self.accumulated_gradients {
grad.fill(A::zero());
}
self.accumulation_count = 0;
}
pub fn accumulation_count(&self) -> usize {
self.accumulation_count
}
pub fn target_accumulations(&self) -> usize {
self.target_accumulations
}
pub fn set_target_accumulations(&mut self, target: usize) {
self.target_accumulations = target;
}
pub fn mode(&self) -> AccumulationMode {
self.mode
}
pub fn set_mode(&mut self, mode: AccumulationMode) {
self.mode = mode;
}
pub fn is_initialized(&self) -> bool {
self.initialized
}
pub fn progress(&self) -> f64 {
if self.target_accumulations == 0 {
1.0
} else {
self.accumulation_count as f64 / self.target_accumulations as f64
}
}
}
/// Gradient accumulator whose target number of accumulations can change
/// according to rules evaluated against the number of completed steps.
pub struct VariableAccumulator<A: Float, D: Dimension> {
/// The accumulator that performs the actual buffering.
accumulator: GradientAccumulator<A, D>,
/// Rules as `(condition, target)` pairs, kept in the order they were added.
adaptive_steps: Vec<(AdaptiveStepCondition, usize)>,
/// Number of times accumulated gradients have been retrieved successfully.
step_count: usize,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> VariableAccumulator<A, D> {
pub fn new(_initialtarget: usize, mode: AccumulationMode) -> Self {
Self {
accumulator: GradientAccumulator::new(_initialtarget, mode),
adaptive_steps: Vec::new(),
step_count: 0,
}
}
pub fn add_adaptive_rule<F>(&mut self, condition: F, accumulationsteps: usize)
where
F: Fn(usize) -> bool + 'static,
{
self.adaptive_steps
.push((Box::new(condition), accumulationsteps));
}
fn update_target(&mut self) {
for (condition, steps) in &self.adaptive_steps {
if condition(self.step_count) {
self.accumulator.set_target_accumulations(*steps);
break;
}
}
}
pub fn accumulate(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
self.update_target();
self.accumulator.accumulate(gradients)
}
pub fn is_ready(&self) -> bool {
self.accumulator.is_ready()
}
pub fn get_and_step(&mut self) -> Result<Vec<Array<A, D>>> {
let result = self.accumulator.get_and_reset()?;
self.step_count += 1;
Ok(result)
}
pub fn step_count(&self) -> usize {
self.step_count
}
pub fn accumulator(&self) -> &GradientAccumulator<A, D> {
&self.accumulator
}
pub fn accumulator_mut(&mut self) -> &mut GradientAccumulator<A, D> {
&mut self.accumulator
}
}
/// Accumulates micro-batch gradients until an effective batch has been
/// processed.
#[derive(Debug)]
pub struct MicroBatchTrainer<A: Float, D: Dimension> {
/// Accumulator whose target is `effective_batch_size / micro_batch_size`.
accumulator: GradientAccumulator<A, D>,
/// Size of one micro-batch, as supplied at construction.
micro_batch_size: usize,
/// Batch size the accumulated micro-batches should add up to.
effective_batch_size: usize,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> MicroBatchTrainer<A, D> {
    /// Creates a trainer that accumulates
    /// `effective_batch_size / micro_batch_size` micro-batches per step.
    ///
    /// The division rounds down, so an effective batch size that is not a
    /// multiple of the micro-batch size yields the next lower step count.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` if `micro_batch_size` is zero or if
    /// `effective_batch_size` is smaller than `micro_batch_size`.
    pub fn new(
        micro_batch_size: usize,
        effective_batch_size: usize,
        mode: AccumulationMode,
    ) -> Result<Self> {
        let accumulation_steps =
            Self::accumulation_steps_for(micro_batch_size, effective_batch_size)?;
        Ok(Self {
            accumulator: GradientAccumulator::new(accumulation_steps, mode),
            micro_batch_size,
            effective_batch_size,
        })
    }

    /// Validates the two batch sizes and returns the number of micro-batches
    /// per effective batch (rounded down).
    fn accumulation_steps_for(
        micro_batch_size: usize,
        effective_batch_size: usize,
    ) -> Result<usize> {
        // Reject zero explicitly: the division below would otherwise panic.
        if micro_batch_size == 0 {
            return Err(OptimError::InvalidConfig(
                "Micro batch size must be > 0".to_string(),
            ));
        }
        if effective_batch_size < micro_batch_size {
            return Err(OptimError::InvalidConfig(
                "Effective batch size must be >= micro batch size".to_string(),
            ));
        }
        Ok(effective_batch_size / micro_batch_size)
    }

    /// Adds the gradients of one micro-batch to the accumulator.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the accumulator.
    pub fn process_micro_batch(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        self.accumulator.accumulate(gradients)
    }

    /// Whether enough micro-batches have been accumulated.
    pub fn ready_for_step(&self) -> bool {
        self.accumulator.is_ready()
    }

    /// Returns the accumulated gradients and resets the accumulator.
    ///
    /// # Errors
    ///
    /// Propagates the accumulator's error when it is not ready yet.
    pub fn get_accumulated_gradients(&mut self) -> Result<Vec<Array<A, D>>> {
        self.accumulator.get_and_reset()
    }

    /// Size of one micro-batch.
    pub fn micro_batch_size(&self) -> usize {
        self.micro_batch_size
    }

    /// Currently configured effective batch size.
    pub fn effective_batch_size(&self) -> usize {
        self.effective_batch_size
    }

    /// Accumulation progress as reported by the accumulator.
    pub fn progress(&self) -> f64 {
        self.accumulator.progress()
    }

    /// Changes the effective batch size and updates the accumulation target to
    /// `effective_batchsize / micro_batch_size` (rounded down).
    ///
    /// Gradients that are already accumulated are kept.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` if `effective_batchsize` is smaller
    /// than the micro-batch size; nothing is changed in that case.
    pub fn set_effective_batch_size(&mut self, effective_batchsize: usize) -> Result<()> {
        let accumulation_steps =
            Self::accumulation_steps_for(self.micro_batch_size, effective_batchsize)?;
        self.effective_batch_size = effective_batchsize;
        self.accumulator
            .set_target_accumulations(accumulation_steps);
        Ok(())
    }
}
pub mod utils {
use super::*;
/// Picks a micro-batch size that fits a memory budget and divides
/// `total_batch_size` evenly.
///
/// The per-sample memory estimate is `param_count * bytes_per_param * 3` and
/// the budget is `max_memory_mb * 1_000_000` bytes. The rationale for the
/// factor of 3 is not stated in this file. Both products saturate instead of
/// overflowing.
///
/// Starting from the smaller of the sample budget and `total_batch_size`, the
/// size is decreased until it divides `total_batch_size`. The result is always
/// at least 1. When the per-sample estimate is zero, memory imposes no limit
/// and `total_batch_size` (or 1 if that is zero) is returned.
pub fn calculate_micro_batch_size(
    total_batch_size: usize,
    max_memory_mb: usize,
    param_count: usize,
    bytes_per_param: usize,
) -> usize {
    let memory_per_sample = param_count
        .saturating_mul(bytes_per_param)
        .saturating_mul(3);
    let memory_budget = max_memory_mb.saturating_mul(1_000_000);

    // `checked_div` yields `None` for a zero estimate, which previously caused
    // a division-by-zero panic.
    let max_samples = memory_budget
        .checked_div(memory_per_sample)
        .unwrap_or(total_batch_size);

    let mut micro_batch_size = max_samples.min(total_batch_size);
    while micro_batch_size > 1 && !total_batch_size.is_multiple_of(micro_batch_size) {
        micro_batch_size -= 1;
    }
    micro_batch_size.max(1)
}
/// Number of micro-batches needed to cover `_total_batch_size`, rounding up.
///
/// Returns 0 when `micro_batch_size` is zero instead of panicking on the
/// division; `validate_config` rejects a step count of zero, so such a value
/// is caught when the configuration is validated.
pub fn calculate_accumulation_steps(
    _total_batch_size: usize,
    micro_batch_size: usize,
) -> usize {
    if micro_batch_size == 0 {
        return 0;
    }
    _total_batch_size.div_ceil(micro_batch_size)
}
/// Checks that a micro-batch configuration is internally consistent.
///
/// # Errors
///
/// Returns `OptimError::InvalidConfig` when any of the three values is zero,
/// or when `effective_batch_size` is not exactly
/// `micro_batch_size * accumulation_steps`. A product that overflows `usize`
/// is treated as a mismatch.
pub fn validate_config(
    micro_batch_size: usize,
    effective_batch_size: usize,
    accumulation_steps: usize,
) -> Result<()> {
    if micro_batch_size == 0 {
        return Err(OptimError::InvalidConfig(
            "Micro batch size must be > 0".to_string(),
        ));
    }
    if effective_batch_size == 0 {
        return Err(OptimError::InvalidConfig(
            "Effective batch size must be > 0".to_string(),
        ));
    }
    if accumulation_steps == 0 {
        return Err(OptimError::InvalidConfig(
            "Accumulation steps must be > 0".to_string(),
        ));
    }
    // `checked_mul` keeps an overflowing product from panicking (debug) or
    // wrapping around to a value that happens to match (release).
    if micro_batch_size.checked_mul(accumulation_steps) != Some(effective_batch_size) {
        return Err(OptimError::InvalidConfig(format!(
            "Effective batch size ({}) != micro batch size ({}) * accumulation steps ({})",
            effective_batch_size, micro_batch_size, accumulation_steps
        )));
    }
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    /// Wraps `values` in the one-array gradient list the accumulators consume.
    fn single_grad(values: &[f64]) -> Vec<Array1<f64>> {
        vec![Array1::from_vec(values.to_vec())]
    }

    /// Three micro-batches are summed element-wise and the state is cleared
    /// after retrieval.
    #[test]
    fn test_gradient_accumulator_sum() {
        let mut acc = GradientAccumulator::new(3, AccumulationMode::Sum);
        let micro_batches = [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [1.0, 1.0, 1.0]];
        for (index, values) in micro_batches.iter().enumerate() {
            acc.accumulate(&single_grad(values)).expect("unwrap failed");
            // Ready only after the last of the three micro-batches.
            assert_eq!(acc.is_ready(), index + 1 == micro_batches.len());
        }

        let summed = acc.get_and_reset().expect("unwrap failed");
        assert_eq!(summed.len(), 1);
        assert_eq!(
            summed[0].as_slice().expect("unwrap failed"),
            &[4.0, 6.0, 8.0]
        );
        assert!(!acc.is_ready());
        assert_eq!(acc.accumulation_count(), 0);
    }

    /// Average mode divides the sum by the number of accumulated batches.
    #[test]
    fn test_gradient_accumulator_average() {
        let mut acc = GradientAccumulator::new(2, AccumulationMode::Average);
        let first = single_grad(&[2.0, 4.0]);
        let second = single_grad(&[4.0, 2.0]);
        acc.accumulate(&first).expect("unwrap failed");
        acc.accumulate(&second).expect("unwrap failed");

        let averaged = acc.get_and_reset().expect("unwrap failed");
        assert_eq!(averaged[0].as_slice().expect("unwrap failed"), &[3.0, 3.0]);
    }

    /// The adaptive rule raises the target to 4 once more than five steps
    /// have completed.
    #[test]
    fn test_variable_accumulator() {
        let mut var_acc = VariableAccumulator::new(2, AccumulationMode::Sum);
        var_acc.add_adaptive_rule(|step| step > 5, 4);
        let grad = single_grad(&[1.0]);

        for _ in 0..2 {
            var_acc.accumulate(&grad).expect("unwrap failed");
        }
        assert!(var_acc.is_ready());
        let _first_step = var_acc.get_and_step().expect("unwrap failed");

        for _ in 0..6 {
            for _ in 0..2 {
                var_acc.accumulate(&grad).expect("unwrap failed");
            }
            if var_acc.is_ready() {
                var_acc.get_and_step().expect("unwrap failed");
            }
        }
        assert_eq!(var_acc.accumulator().target_accumulations(), 4);
    }

    /// An effective batch of 6 with micro-batches of 2 needs three batches.
    #[test]
    fn test_micro_batch_trainer() {
        let mut trainer =
            MicroBatchTrainer::new(2, 6, AccumulationMode::Sum).expect("unwrap failed");
        assert_eq!(trainer.micro_batch_size(), 2);
        assert_eq!(trainer.effective_batch_size(), 6);

        let grad = single_grad(&[1.0, 1.0]);
        for expect_ready in [false, false, true] {
            trainer.process_micro_batch(&grad).expect("unwrap failed");
            assert_eq!(trainer.ready_for_step(), expect_ready);
        }

        let summed = trainer.get_accumulated_gradients().expect("unwrap failed");
        assert_eq!(summed[0].as_slice().expect("unwrap failed"), &[3.0, 3.0]);
    }

    /// The computed micro-batch size is positive and divides the total.
    #[test]
    fn test_calculate_micro_batch_size() {
        let micro_batch = utils::calculate_micro_batch_size(128, 100, 1000, 8);
        assert_eq!(128 % micro_batch, 0);
        assert!(micro_batch > 0);
    }

    /// Step counts round up when the total is not an exact multiple.
    #[test]
    fn test_accumulation_steps_calculation() {
        let cases = [(128, 32, 4), (100, 32, 4), (96, 32, 3)];
        for (total, micro, expected) in cases {
            assert_eq!(utils::calculate_accumulation_steps(total, micro), expected);
        }
    }

    /// A consistent configuration passes; zero or inconsistent values fail.
    #[test]
    fn test_config_validation() {
        utils::validate_config(32, 128, 4).expect("unwrap failed");
        assert!(utils::validate_config(0, 128, 4).is_err());
        assert!(utils::validate_config(32, 100, 4).is_err());
    }

    /// Progress grows by a quarter with each of four accumulations.
    #[test]
    fn test_accumulator_progress() {
        let mut acc = GradientAccumulator::new(4, AccumulationMode::Sum);
        assert_relative_eq!(acc.progress(), 0.0);

        let grad = single_grad(&[1.0]);
        for expected in [0.25, 0.5, 0.75, 1.0] {
            acc.accumulate(&grad).expect("unwrap failed");
            assert_relative_eq!(acc.progress(), expected);
        }
    }

    /// Both a wrong shape and a wrong number of arrays are rejected.
    #[test]
    fn test_dimension_mismatch_error() {
        let mut acc = GradientAccumulator::new(2, AccumulationMode::Sum);
        acc.accumulate(&single_grad(&[1.0, 2.0]))
            .expect("unwrap failed");

        let wrong_shape = single_grad(&[1.0, 2.0, 3.0]);
        assert!(acc.accumulate(&wrong_shape).is_err());

        let wrong_count = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        assert!(acc.accumulate(&wrong_count).is_err());
    }

    /// Retrieving gradients before the target is reached is an error.
    #[test]
    fn test_get_before_ready_error() {
        let mut acc = GradientAccumulator::new(3, AccumulationMode::Sum);
        acc.accumulate(&single_grad(&[1.0])).expect("unwrap failed");
        assert!(acc.get_and_reset().is_err());
    }
}