use scirs2_core::ndarray::{s, Array1, ArrayView1, Ix1, ScalarOperand};
use scirs2_core::numeric::{Float, Zero};
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// Running element-wise sum of gradient vectors together with the number of
/// vectors that have been summed.
pub struct GradientAccumulator<A>
where
    A: Float,
{
    /// Element-wise sum of every gradient added since creation or the last reset.
    accumulated: Array1<A>,
    /// How many gradient vectors are currently folded into `accumulated`.
    count: usize,
}
impl<A: Float + ScalarOperand + Debug + Zero> GradientAccumulator<A> {
pub fn new(size: usize) -> Self {
Self {
accumulated: Array1::zeros(size),
count: 0,
}
}
pub fn accumulate(&mut self, gradients: &ArrayView1<A>) -> Result<()> {
if gradients.len() != self.accumulated.len() {
return Err(crate::error::OptimError::DimensionMismatch(format!(
"Gradient size ({}) doesn't match accumulator size ({})",
gradients.len(),
self.accumulated.len()
)));
}
self.accumulated = &self.accumulated + gradients;
self.count += 1;
Ok(())
}
pub fn count(&self) -> usize {
self.count
}
pub fn average(&mut self) -> Result<Array1<A>> {
if self.count == 0 {
return Err(crate::error::OptimError::InvalidConfig(
"No gradients accumulated".to_string(),
));
}
let scale = A::from(self.count).expect("unwrap failed");
let averaged = &self.accumulated / scale;
self.reset();
Ok(averaged)
}
pub fn reset(&mut self) {
self.accumulated.fill(A::zero());
self.count = 0;
}
pub fn is_ready(&self, target: usize) -> bool {
self.count >= target
}
}
/// Pairs a one-dimensional base optimizer with a chunk length, so that a large
/// parameter vector can be stepped one slice at a time.
pub struct ChunkedOptimizer<O: Optimizer<A, Ix1> + Clone, A: Float + ScalarOperand + Debug> {
    /// Optimizer invoked on each slice of the parameter vector.
    base_optimizer: O,
    /// Maximum number of elements handed to `base_optimizer` per call.
    chunk_size: usize,
    /// Ties the element type `A` to the struct without storing a value of it.
    _phantom: std::marker::PhantomData<A>,
}
impl<O, A> ChunkedOptimizer<O, A>
where
    O: Optimizer<A, Ix1> + Clone,
    A: Float + ScalarOperand + Debug,
{
    /// Wraps `base_optimizer` so that updates are applied chunk by chunk.
    ///
    /// `chunk_size` defaults to 1,000,000 elements when `None`. A requested
    /// size of `0` is raised to `1`, because a zero-length chunk can never
    /// cover the parameter vector and would otherwise cause a division by
    /// zero when the number of chunks is computed.
    pub fn new(base_optimizer: O, chunk_size: Option<usize>) -> Self {
        let chunk_size = chunk_size.unwrap_or(1_000_000).max(1);
        Self {
            base_optimizer,
            chunk_size,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Runs the base optimizer on consecutive slices of at most `chunk_size`
    /// elements and assembles the per-slice results into one updated vector.
    ///
    /// The same `base_optimizer` instance is used for every slice, in order
    /// of increasing index.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` when `params` and `gradients`
    /// differ in length, and propagates any error returned by the base
    /// optimizer's `step`.
    pub fn step_chunked(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>> {
        if params.len() != gradients.len() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Parameters ({}) and gradients ({}) must have same size",
                params.len(),
                gradients.len()
            )));
        }
        let total_size = params.len();
        let mut updated = Array1::zeros(total_size);
        // `chunk_size` is at least 1 (enforced in `new`), so `step_by` is valid.
        for start in (0..total_size).step_by(self.chunk_size) {
            // The final chunk may be shorter than `chunk_size`.
            let end = start.saturating_add(self.chunk_size).min(total_size);
            let params_chunk = params.slice(s![start..end]).to_owned();
            let grads_chunk = gradients.slice(s![start..end]).to_owned();
            let updated_chunk = self.base_optimizer.step(&params_chunk, &grads_chunk)?;
            updated.slice_mut(s![start..end]).assign(&updated_chunk);
        }
        Ok(updated)
    }

    /// Maximum number of elements processed per chunk.
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Number of chunks needed to cover `total_size` elements (rounded up;
    /// `0` when `total_size` is `0`).
    pub fn num_chunks(&self, total_size: usize) -> usize {
        total_size.div_ceil(self.chunk_size)
    }
}
/// Namespace for rough, closed-form memory estimates used when sizing
/// optimizer runs. All results are in bytes unless stated otherwise, and all
/// arithmetic saturates at `usize::MAX` instead of overflowing.
pub struct MemoryUsageEstimator;

impl MemoryUsageEstimator {
    /// Bytes for `copies` buffers of `num_params` elements of `dtype_size` bytes each.
    fn scaled_bytes(num_params: usize, dtype_size: usize, copies: usize) -> usize {
        num_params.saturating_mul(dtype_size).saturating_mul(copies)
    }

    /// Estimate for SGD: two parameter-sized buffers.
    pub fn sgd(num_params: usize, dtype_size: usize) -> usize {
        Self::scaled_bytes(num_params, dtype_size, 2)
    }

    /// Estimate for SGD with momentum: three parameter-sized buffers.
    pub fn sgd_with_momentum(num_params: usize, dtype_size: usize) -> usize {
        Self::scaled_bytes(num_params, dtype_size, 3)
    }

    /// Estimate for Adam: four parameter-sized buffers.
    pub fn adam(num_params: usize, dtype_size: usize) -> usize {
        Self::scaled_bytes(num_params, dtype_size, 4)
    }

    /// Recommends how many parameters to process per chunk.
    ///
    /// The per-parameter cost is `dtype_size * optimizer_state_multiplier`
    /// bytes. The recommendation is 80% of the number of parameters that fit
    /// in `available_memory_bytes`, capped at `total_params`, and never less
    /// than 1024.
    ///
    /// When the per-parameter cost is zero, memory imposes no limit and the
    /// result is `total_params` (still subject to the 1024 floor).
    pub fn recommend_chunk_size(
        total_params: usize,
        available_memory_bytes: usize,
        dtype_size: usize,
        optimizer_state_multiplier: usize,
    ) -> usize {
        const MIN_CHUNK: usize = 1024;

        let memory_per_param = dtype_size.saturating_mul(optimizer_state_multiplier);
        if memory_per_param == 0 {
            // Avoid dividing by zero: a zero-cost parameter is unconstrained.
            return total_params.max(MIN_CHUNK);
        }
        let max_params = available_memory_bytes / memory_per_param;
        // Take 80% through a u128 intermediate so `max_params * 80` cannot
        // overflow; the quotient is <= max_params and therefore fits in usize.
        let safe_params = ((max_params as u128 * 80) / 100) as usize;
        safe_params.min(total_params).max(MIN_CHUNK)
    }

    /// Number of micro-batches of at most `max_micro_batch_size` samples
    /// needed to reach `target_batch_size` (rounded up).
    ///
    /// A `max_micro_batch_size` of `0` is treated as `1` so the computation
    /// cannot divide by zero.
    pub fn recommend_accumulation_steps(
        target_batch_size: usize,
        max_micro_batch_size: usize,
    ) -> usize {
        target_batch_size.div_ceil(max_micro_batch_size.max(1))
    }

    /// Rough peak-memory estimate: parameters + gradients + optimizer state +
    /// activations.
    ///
    /// Optimizer state is budgeted as two parameter-sized buffers for
    /// `"adam"` / `"adamw"` and one for every other `optimizer_type`
    /// (including `"sgd"`). The match is case-sensitive. Activations are
    /// modelled as `batch_size * sequence_length * hidden_dim * dtype_size`,
    /// where `hidden_dim` is the truncated square root of `num_params`.
    pub fn estimate_peak_memory(
        num_params: usize,
        batch_size: usize,
        sequence_length: usize,
        dtype_size: usize,
        optimizer_type: &str,
    ) -> usize {
        let param_memory = Self::scaled_bytes(num_params, dtype_size, 1);
        let grad_memory = param_memory;
        let state_copies = match optimizer_type {
            "adam" | "adamw" => 2,
            _ => 1,
        };
        let optimizer_memory = Self::scaled_bytes(num_params, dtype_size, state_copies);
        // Heuristic: treat the model as roughly square, so width ~ sqrt(params).
        let hidden_dim = (num_params as f64).sqrt() as usize;
        let activation_memory = batch_size
            .saturating_mul(sequence_length)
            .saturating_mul(hidden_dim)
            .saturating_mul(dtype_size);
        param_memory
            .saturating_add(grad_memory)
            .saturating_add(optimizer_memory)
            .saturating_add(activation_memory)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;
    use approx::assert_relative_eq;

    /// Two accumulated gradients average to their mean, and averaging resets the count.
    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::<f32>::new(100);
        let grad1 = Array1::from_elem(100, 1.0);
        let grad2 = Array1::from_elem(100, 2.0);
        accumulator
            .accumulate(&grad1.view())
            .expect("gradient length matches accumulator length");
        accumulator
            .accumulate(&grad2.view())
            .expect("gradient length matches accumulator length");
        assert_eq!(accumulator.count(), 2);
        assert!(accumulator.is_ready(2));
        assert!(!accumulator.is_ready(3));
        let avg = accumulator
            .average()
            .expect("two gradients were accumulated");
        assert_relative_eq!(avg[0], 1.5, epsilon = 1e-6);
        assert_eq!(accumulator.count(), 0);
    }

    /// Error paths: averaging with nothing accumulated and a length mismatch.
    #[test]
    fn test_gradient_accumulator_errors() {
        let mut accumulator = GradientAccumulator::<f32>::new(3);
        assert!(accumulator.average().is_err());
        let wrong_len = Array1::<f32>::zeros(4);
        assert!(accumulator.accumulate(&wrong_len.view()).is_err());
        // A rejected gradient must not be counted.
        assert_eq!(accumulator.count(), 0);
    }

    /// 25 parameters with chunk size 10 are processed as three chunks.
    #[test]
    fn test_chunked_optimizer() {
        let optimizer = SGD::new(0.01);
        let mut chunked_opt = ChunkedOptimizer::new(optimizer, Some(10));
        let params = Array1::from_vec((0..25).map(|i| i as f32).collect());
        let gradients = Array1::from_elem(25, 0.1);
        let updated = chunked_opt
            .step_chunked(&params, &gradients)
            .expect("params and gradients have the same length");
        assert_eq!(updated.len(), 25);
        assert_relative_eq!(updated[0], 0.0 - 0.01 * 0.1, epsilon = 1e-6);
        assert_eq!(chunked_opt.chunk_size(), 10);
        assert_eq!(chunked_opt.num_chunks(25), 3);
        assert_eq!(chunked_opt.num_chunks(30), 3);
        assert_eq!(chunked_opt.num_chunks(0), 0);
    }

    #[test]
    fn test_memory_estimator_sgd() {
        let mem = MemoryUsageEstimator::sgd(1_000_000, 4);
        assert_eq!(mem, 8_000_000);
        let mem = MemoryUsageEstimator::sgd_with_momentum(1_000_000, 4);
        assert_eq!(mem, 12_000_000);
    }

    #[test]
    fn test_memory_estimator_adam() {
        let mem = MemoryUsageEstimator::adam(1_000_000, 4);
        assert_eq!(mem, 16_000_000);
    }

    #[test]
    fn test_recommend_chunk_size() {
        let chunk_size =
            MemoryUsageEstimator::recommend_chunk_size(100_000_000, 1_000_000_000, 4, 4);
        assert!(chunk_size > 40_000_000);
        assert!(chunk_size < 60_000_000);
    }

    #[test]
    fn test_recommend_accumulation_steps() {
        let steps = MemoryUsageEstimator::recommend_accumulation_steps(128, 32);
        assert_eq!(steps, 4);
        // 100 / 32 rounds up to 4.
        let steps = MemoryUsageEstimator::recommend_accumulation_steps(100, 32);
        assert_eq!(steps, 4);
    }

    #[test]
    fn test_estimate_peak_memory() {
        let peak = MemoryUsageEstimator::estimate_peak_memory(10_000_000, 32, 512, 4, "adam");
        assert!(peak > 100_000_000);
    }
}