//! Gradient accumulation utilities: buffer per-parameter gradients across
//! micro-batches, optionally normalize by step count and clip by global L2
//! norm, and track statistics about the resulting optimizer updates.

use std::fmt;
/// Errors produced while accumulating or finalizing gradients.
#[derive(Debug, Clone, PartialEq)]
pub enum GradError {
    /// The number of incoming gradient tensors does not match the number of
    /// parameter buffers.
    ShapeMismatch { expected: usize, got: usize },
    /// An incoming tensor's element count does not match its buffer slot.
    TensorLengthMismatch {
        param_idx: usize,
        expected: usize,
        got: usize,
    },
    /// `finalize` was called before any gradients were accumulated.
    EmptyBuffer,
}
impl fmt::Display for GradError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
GradError::ShapeMismatch { expected, got } => write!(
f,
"gradient shape mismatch: expected {} parameter tensors, got {}",
expected, got
),
GradError::TensorLengthMismatch {
param_idx,
expected,
got,
} => write!(
f,
"tensor length mismatch at param {}: expected {} elements, got {}",
param_idx, expected, got
),
GradError::EmptyBuffer => {
write!(f, "cannot finalize: no gradients have been accumulated")
}
}
}
}
impl std::error::Error for GradError {}
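/// Configuration for gradient accumulation.
///
/// A construction sketch with non-default values (illustrative, not tuned
/// recommendations):
///
/// ```ignore
/// let config = GradAccumConfig {
///     accumulation_steps: 8,
///     normalize_by_steps: true,
///     clip_grad_norm: Some(5.0),
///     sync_on_last: true,
/// };
/// ```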
#[derive(Debug, Clone)]
pub struct GradAccumConfig {
    /// Number of micro-batches to accumulate before an update is ready.
    pub accumulation_steps: usize,
    /// Divide the accumulated sum by the step count before returning it,
    /// yielding the mean gradient rather than the sum.
    pub normalize_by_steps: bool,
    /// If set, rescale finalized gradients so their global L2 norm does not
    /// exceed this value.
    pub clip_grad_norm: Option<f32>,
    /// Whether to synchronize across workers on the last micro-batch
    /// (reserved for distributed training; not consulted in this module).
    pub sync_on_last: bool,
}
impl Default for GradAccumConfig {
fn default() -> Self {
Self {
accumulation_steps: 4,
normalize_by_steps: true,
clip_grad_norm: Some(1.0),
sync_on_last: true,
}
}
}
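/// L2 norm over all parameter gradients, treating the nested vectors as one
/// flattened vector: `sqrt(sum of squared elements)`.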
pub fn global_grad_norm(grads: &[Vec<f32>]) -> f32 {
let sum_sq: f32 = grads.iter().flat_map(|g| g.iter()).map(|v| v * v).sum();
sum_sq.sqrt()
}
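/// Scales all gradients in place so their global L2 norm does not exceed
/// `max_norm` (the usual clip-by-global-norm scheme); returns the norm
/// measured before clipping. For example, tensors `[3.0]` and `[4.0]` have
/// global norm 5.0, so with `max_norm = 1.0` every element is scaled by 0.2.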
pub fn clip_grad_norm_(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
let norm = global_grad_norm(grads);
if norm > max_norm {
let scale = max_norm / norm;
for param_grads in grads.iter_mut() {
for g in param_grads.iter_mut() {
*g *= scale;
}
}
}
norm
}
/// Accumulates gradients across micro-batches, one `Vec<f32>` per parameter
/// tensor.
#[derive(Debug)]
pub struct GradientBuffer {
    /// Running element-wise sums, same lengths as the parameter tensors.
    pub gradients: Vec<Vec<f32>>,
    /// Number of micro-batches accumulated since the last reset.
    pub step_count: usize,
    /// True once `accumulation_steps` micro-batches have been accumulated.
    pub is_ready: bool,
}
impl GradientBuffer {
    /// Creates a zeroed buffer; `param_shapes` gives the element count of
    /// each parameter tensor.
    pub fn new(param_shapes: &[usize]) -> Self {
let gradients = param_shapes.iter().map(|&n| vec![0.0_f32; n]).collect();
Self {
gradients,
step_count: 0,
is_ready: false,
}
}
    /// Adds one micro-batch of gradients element-wise into the buffer and
    /// marks the buffer ready once `config.accumulation_steps` micro-batches
    /// have been accumulated.
    pub fn accumulate(
        &mut self,
        grads: &[Vec<f32>],
        config: &GradAccumConfig,
    ) -> Result<(), GradError> {
if grads.len() != self.gradients.len() {
return Err(GradError::ShapeMismatch {
expected: self.gradients.len(),
got: grads.len(),
});
}
for (idx, (buf, incoming)) in
self.gradients.iter_mut().zip(grads.iter()).enumerate()
{
if buf.len() != incoming.len() {
return Err(GradError::TensorLengthMismatch {
param_idx: idx,
expected: buf.len(),
got: incoming.len(),
});
}
for (b, g) in buf.iter_mut().zip(incoming.iter()) {
*b += g;
}
}
self.step_count += 1;
if self.step_count == config.accumulation_steps {
self.is_ready = true;
}
Ok(())
}
    /// Returns the accumulated gradients, normalized and clipped per
    /// `config`, then resets the buffer.
    pub fn finalize(
        &mut self,
        config: &GradAccumConfig,
    ) -> Result<Vec<Vec<f32>>, GradError> {
        if self.step_count == 0 {
            return Err(GradError::EmptyBuffer);
        }
        let mut result = self.gradients.clone();
        // The step_count == 0 guard above makes the divisor nonzero here.
        if config.normalize_by_steps {
let divisor = self.step_count as f32;
for param_grads in result.iter_mut() {
for g in param_grads.iter_mut() {
*g /= divisor;
}
}
}
if let Some(max_norm) = config.clip_grad_norm {
clip_grad_norm_(&mut result, max_norm);
}
self.reset();
Ok(result)
}
    /// Zeroes all gradients and clears the step counter and ready flag.
    pub fn reset(&mut self) {
for param_grads in self.gradients.iter_mut() {
for g in param_grads.iter_mut() {
*g = 0.0;
}
}
self.step_count = 0;
self.is_ready = false;
}
pub fn current_step(&self) -> usize {
self.step_count
}
}
/// Running statistics over the optimizer updates produced by accumulation.
#[derive(Debug, Clone, Default)]
pub struct GradAccumStats {
    pub total_updates: u64,
    pub total_micro_batches: u64,
    pub clips_applied: u64,
    /// Running mean of the pre-clip gradient norm across updates.
    pub mean_grad_norm_before_clip: f32,
    pub max_grad_norm_seen: f32,
}
impl GradAccumStats {
    /// Records one optimizer update: its pre-clip gradient norm and whether
    /// clipping fired.
    pub fn record_update(&mut self, grad_norm: f32, was_clipped: bool) {
self.total_updates += 1;
if was_clipped {
self.clips_applied += 1;
}
        // Incremental running mean: mean += (x - mean) / n.
        let n = self.total_updates as f32;
        self.mean_grad_norm_before_clip +=
            (grad_norm - self.mean_grad_norm_before_clip) / n;
if grad_norm > self.max_grad_norm_seen {
self.max_grad_norm_seen = grad_norm;
}
}
}
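/// Drives a `GradientBuffer` according to a `GradAccumConfig` and records
/// per-update statistics.
///
/// A minimal usage sketch (illustrative only: `backward_pass`,
/// `optimizer_step`, `model`, and `batches` are hypothetical placeholders
/// for the caller's training code):
///
/// ```ignore
/// let mut acc = GradientAccumulator::new(GradAccumConfig::default(), &[1024, 256]);
/// for batch in batches {
///     let micro_grads = backward_pass(&model, &batch);
///     if let Some(grads) = acc.step(&micro_grads)? {
///         optimizer_step(&mut model, &grads); // one update per full window
///     }
/// }
/// // Flush a trailing partial window at the end of the epoch.
/// if let Some(grads) = acc.force_finalize()? {
///     optimizer_step(&mut model, &grads);
/// }
/// ```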
pub struct GradientAccumulator {
pub config: GradAccumConfig,
pub buffer: GradientBuffer,
pub stats: GradAccumStats,
}
impl GradientAccumulator {
pub fn new(config: GradAccumConfig, param_shapes: &[usize]) -> Self {
let buffer = GradientBuffer::new(param_shapes);
Self {
config,
buffer,
stats: GradAccumStats::default(),
}
}
    /// Accumulates one micro-batch of gradients. Returns `Ok(Some(grads))`
    /// when a full window of `accumulation_steps` micro-batches has been
    /// collected (normalized and clipped per the config), `Ok(None)` otherwise.
    pub fn step(
&mut self,
grads: &[Vec<f32>],
) -> Result<Option<Vec<Vec<f32>>>, GradError> {
self.buffer.accumulate(grads, &self.config)?;
self.stats.total_micro_batches += 1;
        if self.buffer.is_ready {
            let norm_before = global_grad_norm(&self.buffer.gradients);
            let result = self.buffer.finalize(&self.config)?;
            // The L2 norm scales linearly under scalar division, so the
            // pre-clip norm of the normalized gradients is the raw norm
            // divided by the window size.
            let stats_norm = if self.config.normalize_by_steps {
                norm_before / self.config.accumulation_steps as f32
            } else {
                norm_before
            };
            // `finalize` clips after normalizing, so clipping fires when the
            // normalized norm exceeds the threshold, not when the raw sum does.
            let was_clipped = self
                .config
                .clip_grad_norm
                .map(|max| stats_norm > max)
                .unwrap_or(false);
            self.stats.record_update(stats_norm, was_clipped);
            Ok(Some(result))
} else {
Ok(None)
}
}
    /// Finalizes whatever has accumulated so far, even if fewer than
    /// `accumulation_steps` micro-batches have arrived (useful at the end of
    /// an epoch). Returns `Ok(None)` when the buffer is empty.
    pub fn force_finalize(
        &mut self,
    ) -> Result<Option<Vec<Vec<f32>>>, GradError> {
        if self.buffer.step_count == 0 {
            return Ok(None);
        }
        // Capture the step count before `finalize` resets the buffer.
        let steps = self.buffer.step_count;
        let norm_before = global_grad_norm(&self.buffer.gradients);
        let result = self.buffer.finalize(&self.config)?;
        // Normalize by the number of steps actually taken, not the configured
        // window size, since this is a partial window.
        let stats_norm = if self.config.normalize_by_steps {
            norm_before / steps as f32
        } else {
            norm_before
        };
        let was_clipped = self
            .config
            .clip_grad_norm
            .map(|max| stats_norm > max)
            .unwrap_or(false);
        self.stats.record_update(stats_norm, was_clipped);
        Ok(Some(result))
    }
pub fn stats(&self) -> &GradAccumStats {
&self.stats
}
}
#[cfg(test)]
mod tests {
use super::*;
fn bare_config(steps: usize) -> GradAccumConfig {
GradAccumConfig {
accumulation_steps: steps,
normalize_by_steps: false,
clip_grad_norm: None,
sync_on_last: true,
}
}
fn norm_config(steps: usize) -> GradAccumConfig {
GradAccumConfig {
accumulation_steps: steps,
normalize_by_steps: true,
clip_grad_norm: None,
sync_on_last: true,
}
}
#[test]
fn test_accumulate_4_steps_then_finalize() {
let config = bare_config(4);
let param_shapes = &[3_usize, 2];
let mut buf = GradientBuffer::new(param_shapes);
let grads_a = vec![vec![1.0_f32; 3], vec![1.0_f32; 2]];
for step in 0..4 {
let result = buf.accumulate(&grads_a, &config);
assert!(result.is_ok());
if step < 3 {
assert!(!buf.is_ready, "should not be ready before step 4");
} else {
assert!(buf.is_ready, "should be ready after step 4");
}
}
let finalized = buf.finalize(&config).expect("finalize should succeed");
assert_eq!(finalized.len(), 2);
for v in &finalized[0] {
assert!((v - 4.0).abs() < 1e-6, "expected 4.0, got {v}");
}
for v in &finalized[1] {
assert!((v - 4.0).abs() < 1e-6, "expected 4.0, got {v}");
}
assert_eq!(buf.step_count, 0);
assert!(!buf.is_ready);
}
#[test]
fn test_normalization_divides_by_steps() {
let config = norm_config(4);
let param_shapes = &[4_usize];
let mut buf = GradientBuffer::new(param_shapes);
let grads = vec![vec![2.0_f32; 4]];
for _ in 0..4 {
buf.accumulate(&grads, &config).expect("accumulate");
}
let finalized = buf.finalize(&config).expect("finalize");
for v in &finalized[0] {
assert!(
(v - 2.0).abs() < 1e-6,
"expected mean 2.0 after dividing sum 8.0 by 4, got {v}"
);
}
}
#[test]
fn test_gradient_clipping_scales_down() {
let mut grads = vec![vec![3.0_f32], vec![3.0_f32]];
let original_norm = clip_grad_norm_(&mut grads, 1.0);
        let original_expected = (2.0_f32 * 9.0_f32).sqrt();
        assert!(
(original_norm - original_expected).abs() < 1e-4,
"expected original norm ~{original_expected}, got {original_norm}"
);
let new_norm = global_grad_norm(&grads);
assert!(
(new_norm - 1.0).abs() < 1e-5,
"expected clipped norm 1.0, got {new_norm}"
);
}
#[test]
fn test_no_clipping_when_under_max() {
let mut grads = vec![vec![0.5_f32], vec![0.5_f32]];
let original = grads.clone();
clip_grad_norm_(&mut grads, 1.0);
assert_eq!(grads, original, "grads should not change when under max_norm");
}
#[test]
fn test_reset_clears_buffer() {
let config = bare_config(4);
let param_shapes = &[2_usize];
let mut buf = GradientBuffer::new(param_shapes);
buf.accumulate(&[vec![5.0, 5.0]], &config).expect("accumulate");
assert_eq!(buf.step_count, 1);
buf.reset();
assert_eq!(buf.step_count, 0);
assert!(!buf.is_ready);
assert_eq!(buf.gradients[0], vec![0.0_f32, 0.0_f32]);
}
#[test]
fn test_force_finalize_partial_batch() {
        let config = bare_config(4);
        let param_shapes = &[2_usize];
let mut acc = GradientAccumulator::new(config, param_shapes);
let grads = vec![vec![1.0_f32, 2.0_f32]];
acc.step(&grads).expect("step 1");
acc.step(&grads).expect("step 2");
assert!(!acc.buffer.is_ready);
let forced = acc.force_finalize().expect("force_finalize ok");
assert!(forced.is_some(), "should have gradients");
let result = forced.expect("should be some");
assert!((result[0][0] - 2.0).abs() < 1e-6);
assert!((result[0][1] - 4.0).abs() < 1e-6);
}
#[test]
fn test_stats_tracking() {
let config = GradAccumConfig {
accumulation_steps: 2,
normalize_by_steps: false,
clip_grad_norm: Some(1.0),
sync_on_last: true,
};
let param_shapes = &[1_usize];
let mut acc = GradientAccumulator::new(config, param_shapes);
acc.step(&[vec![2.0_f32]]).expect("step 1a");
acc.step(&[vec![2.0_f32]]).expect("step 1b → update 1");
assert_eq!(acc.stats().total_updates, 1);
assert_eq!(acc.stats().total_micro_batches, 2);
assert_eq!(acc.stats().clips_applied, 1);
acc.step(&[vec![0.1_f32]]).expect("step 2a");
acc.step(&[vec![0.1_f32]]).expect("step 2b → update 2");
assert_eq!(acc.stats().total_updates, 2);
assert_eq!(acc.stats().clips_applied, 1, "no additional clip expected");
}
#[test]
fn test_multiple_parameters() {
let config = bare_config(2);
let param_shapes = &[3_usize, 5, 2];
let mut buf = GradientBuffer::new(param_shapes);
let g = vec![
vec![1.0_f32, 2.0, 3.0],
vec![0.5_f32; 5],
vec![10.0_f32, 20.0],
];
buf.accumulate(&g, &config).expect("step 1");
buf.accumulate(&g, &config).expect("step 2");
let finalized = buf.finalize(&config).expect("finalize");
assert_eq!(finalized[0], vec![2.0, 4.0, 6.0]);
for v in &finalized[1] {
assert!((v - 1.0).abs() < 1e-6);
}
assert_eq!(finalized[2], vec![20.0, 40.0]);
}
#[test]
fn test_error_on_shape_mismatch() {
let config = bare_config(4);
let param_shapes = &[3_usize, 2];
let mut buf = GradientBuffer::new(param_shapes);
let grads = vec![vec![1.0_f32; 3]];
let result = buf.accumulate(&grads, &config);
        match result {
            Err(GradError::ShapeMismatch { expected: 2, got: 1 }) => {}
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
}
#[test]
fn test_clip_returns_original_norm() {
let mut grads = vec![vec![5.0_f32]];
let returned = clip_grad_norm_(&mut grads, 2.0);
assert!(
(returned - 5.0).abs() < 1e-5,
"expected original norm 5.0, got {returned}"
);
let after = global_grad_norm(&grads);
assert!(
(after - 2.0).abs() < 1e-5,
"expected clipped norm 2.0, got {after}"
);
}
#[test]
fn test_error_on_tensor_length_mismatch() {
let config = bare_config(4);
let param_shapes = &[3_usize];
let mut buf = GradientBuffer::new(param_shapes);
        let grads = vec![vec![1.0_f32, 2.0]];
        let result = buf.accumulate(&grads, &config);
        match result {
            Err(GradError::TensorLengthMismatch {
                param_idx: 0,
                expected: 3,
                got: 2,
            }) => {}
            other => panic!("expected TensorLengthMismatch, got {:?}", other),
        }
}
#[test]
fn test_finalize_empty_buffer_error() {
let config = bare_config(4);
let param_shapes = &[2_usize];
let mut buf = GradientBuffer::new(param_shapes);
let result = buf.finalize(&config);
assert_eq!(result, Err(GradError::EmptyBuffer));
}
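    // Supplementary sketch: normalization runs before clipping inside
    // `finalize`, so four micro-batches of 2.0 sum to 8.0, normalize to the
    // mean 2.0, and are then clipped down to global norm 1.0.
    #[test]
    fn test_normalize_then_clip_interplay() {
        let config = GradAccumConfig {
            accumulation_steps: 4,
            normalize_by_steps: true,
            clip_grad_norm: Some(1.0),
            sync_on_last: true,
        };
        let mut buf = GradientBuffer::new(&[1_usize]);
        for _ in 0..4 {
            buf.accumulate(&[vec![2.0_f32]], &config).expect("accumulate");
        }
        let finalized = buf.finalize(&config).expect("finalize");
        assert!(
            (finalized[0][0] - 1.0).abs() < 1e-6,
            "expected clipped mean 1.0, got {}",
            finalized[0][0]
        );
    }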
}