use std::fmt;
use std::marker::PhantomData;
use scivex_core::Float;
use crate::error::{NnError, Result};
use crate::optim::Optimizer;
use crate::variable::Variable;
/// Wraps an [`Optimizer`] so that it is stepped only once every
/// `accumulation_steps` calls to [`GradAccumulator::step`].
///
/// When an update fires, each tracked parameter that currently holds a
/// gradient has that gradient multiplied by `1 / n` (where `n` is the number
/// of calls counted since the previous update); the wrapped optimizer is then
/// stepped and asked to zero its gradients.
///
/// NOTE(review): nothing in this type sums gradients itself; presumably the
/// backward pass adds into each parameter's gradient between calls — confirm
/// against `Variable`.
pub struct GradAccumulator<T, O>
where
    T: Float,
    O: Optimizer<T>,
{
    /// The wrapped optimizer that performs the actual update.
    optimizer: O,
    /// Parameters whose gradients are rescaled right before the optimizer runs.
    parameters: Vec<Variable<T>>,
    /// Number of `step` calls per optimizer update; `new` rejects zero.
    accumulation_steps: usize,
    /// Calls counted since the last optimizer update (or since `reset`).
    current_step: usize,
    /// Marker for the scalar type `T`.
    _marker: PhantomData<T>,
}
/// Debug output shows the step configuration and the number of tracked
/// parameters. The optimizer and the parameter values are omitted, which is
/// why the struct is rendered as non-exhaustive.
impl<T, O> fmt::Debug for GradAccumulator<T, O>
where
    T: Float,
    O: Optimizer<T>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let num_parameters = self.parameters.len();

        let mut out = f.debug_struct("GradAccumulator");
        out.field("accumulation_steps", &self.accumulation_steps);
        out.field("current_step", &self.current_step);
        out.field("num_parameters", &num_parameters);
        out.finish_non_exhaustive()
    }
}
impl<T, O> GradAccumulator<T, O>
where
    T: Float,
    O: Optimizer<T>,
{
    /// Builds an accumulator around `optimizer` that triggers one optimizer
    /// update per `accumulation_steps` calls to [`step`](Self::step).
    ///
    /// `parameters` are the variables whose gradients get rescaled before each
    /// update. NOTE(review): presumably these are the same variables the
    /// optimizer was built with — nothing here checks that.
    ///
    /// # Errors
    ///
    /// Returns [`NnError::InvalidParameter`] when `accumulation_steps` is `0`.
    pub fn new(
        optimizer: O,
        parameters: Vec<Variable<T>>,
        accumulation_steps: usize,
    ) -> Result<Self> {
        match accumulation_steps {
            0 => Err(NnError::InvalidParameter {
                name: "accumulation_steps",
                reason: "must be at least 1",
            }),
            _ => Ok(Self {
                optimizer,
                parameters,
                accumulation_steps,
                current_step: 0,
                _marker: PhantomData,
            }),
        }
    }

    /// Counts one more call and, once `accumulation_steps` calls have been
    /// counted, rescales the gradients and steps the wrapped optimizer.
    ///
    /// Returns `true` when the optimizer was stepped by this call (the counter
    /// is back at zero afterwards) and `false` otherwise.
    pub fn step(&mut self) -> bool {
        self.current_step += 1;

        let window_complete = self.current_step >= self.accumulation_steps;
        if window_complete {
            self.scale_and_step();
        }
        window_complete
    }

    /// Applies whatever has been counted so far without waiting for a full
    /// window; the gradients are scaled by the number of calls actually
    /// counted. Does nothing when no call has been counted.
    pub fn flush(&mut self) {
        if self.current_step == 0 {
            return;
        }
        self.scale_and_step();
    }

    /// Drops the current window: the counter returns to zero and the wrapped
    /// optimizer is asked to zero its gradients. The optimizer is not stepped.
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.optimizer.zero_grad();
    }

    /// Shared access to the wrapped optimizer.
    pub fn inner(&self) -> &O {
        &self.optimizer
    }

    /// Exclusive access to the wrapped optimizer.
    pub fn inner_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }

    /// Number of `step` calls that make up one optimizer update.
    pub fn accumulation_steps(&self) -> usize {
        self.accumulation_steps
    }

    /// Number of `step` calls counted since the last optimizer update.
    pub fn current_step(&self) -> usize {
        self.current_step
    }

    /// Multiplies every present gradient by `1 / current_step`, steps the
    /// optimizer, zeroes its gradients and restarts the counter.
    ///
    /// Both callers guarantee `current_step > 0`, so the divisor is never zero.
    fn scale_and_step(&mut self) {
        let factor = T::one() / T::from_usize(self.current_step);

        // Parameters without a gradient are left untouched.
        self.parameters.iter().for_each(|param| {
            if let Some(grad) = param.grad() {
                param.set_grad(&grad * factor);
            }
        });

        self.optimizer.step();
        self.optimizer.zero_grad();
        self.current_step = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::SGD;
    use scivex_core::Tensor;

    /// A new accumulator reports its configured window and a zero counter.
    #[test]
    fn test_construction() {
        let weight = Variable::new(Tensor::from_vec(vec![1.0_f64, 2.0], vec![2]).unwrap(), true);
        let params = vec![weight];
        let optimizer = SGD::new(params.clone(), 0.01);

        let accum = GradAccumulator::new(optimizer, params, 4).unwrap();

        assert_eq!(accum.current_step(), 0);
        assert_eq!(accum.accumulation_steps(), 4);
    }

    /// A window of zero steps is rejected with an `InvalidParameter` error
    /// naming the offending argument.
    #[test]
    fn test_zero_accumulation_steps_errors() {
        let weight = Variable::new(Tensor::from_vec(vec![1.0_f64], vec![1]).unwrap(), true);
        let params = vec![weight];
        let optimizer = SGD::new(params.clone(), 0.01);

        let outcome = GradAccumulator::new(optimizer, params, 0);

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        match err {
            NnError::InvalidParameter { name, .. } => assert_eq!(name, "accumulation_steps"),
            other => panic!("unexpected error: {other}"),
        }
    }

    /// With a window of three, only the third call reports an optimizer step,
    /// and the counter is back at zero afterwards.
    #[test]
    fn test_step_counting() {
        let w = Variable::new(Tensor::from_vec(vec![1.0_f64, 2.0], vec![2]).unwrap(), true);
        let params = vec![w.clone()];
        let optimizer = SGD::new(params.clone(), 0.1);
        let mut accum = GradAccumulator::new(optimizer, params, 3).unwrap();

        for should_fire in [false, false, true] {
            w.set_grad(Tensor::from_vec(vec![1.0, 1.0], vec![2]).unwrap());
            assert_eq!(accum.step(), should_fire);
        }

        assert_eq!(accum.current_step(), 0);
    }

    /// The gradient present when the window completes (12) is halved for a
    /// window of two, so with learning rate 1 the weight is expected to move
    /// from 0 to -6.
    #[test]
    fn test_gradient_scaling() {
        let w = Variable::new(Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap(), true);
        let params = vec![w.clone()];
        let optimizer = SGD::new(params.clone(), 1.0);
        let mut accum = GradAccumulator::new(optimizer, params, 2).unwrap();

        w.set_grad(Tensor::from_vec(vec![6.0_f64], vec![1]).unwrap());
        let fired_first = accum.step();
        assert!(!fired_first);

        w.set_grad(Tensor::from_vec(vec![12.0_f64], vec![1]).unwrap());
        let fired_second = accum.step();
        assert!(fired_second);

        let val = w.data().as_slice()[0];
        assert!((val + 6.0).abs() < 1e-10, "expected w = -6.0, got {val}");
    }

    /// Flushing after a single counted call scales by 1/1, so the gradient of
    /// 8 is expected to be applied in full.
    #[test]
    fn test_flush_before_full_accumulation() {
        let w = Variable::new(Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap(), true);
        let params = vec![w.clone()];
        let optimizer = SGD::new(params.clone(), 1.0);
        let mut accum = GradAccumulator::new(optimizer, params, 4).unwrap();

        w.set_grad(Tensor::from_vec(vec![8.0_f64], vec![1]).unwrap());
        let _fired = accum.step();
        accum.flush();

        let val = w.data().as_slice()[0];
        assert!((val + 8.0).abs() < 1e-10, "expected w = -8.0, got {val}");
        assert_eq!(accum.current_step(), 0);
    }

    /// Flushing with nothing counted must leave the weight untouched.
    #[test]
    fn test_flush_no_op_when_no_steps() {
        let w = Variable::new(Tensor::from_vec(vec![5.0_f64], vec![1]).unwrap(), true);
        let params = vec![w.clone()];
        let optimizer = SGD::new(params.clone(), 1.0);
        let mut accum = GradAccumulator::new(optimizer, params, 4).unwrap();

        accum.flush();

        let val = w.data().as_slice()[0];
        let drift = (val - 5.0).abs();
        assert!(drift < 1e-10, "expected w unchanged at 5.0, got {val}");
    }
}