optirs_core/gradient_accumulation/
mod.rs

// Gradient accumulation for large batch training
//
// This module provides utilities for accumulating gradients across multiple
// micro-batches to simulate larger batch sizes without increasing memory usage.

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// Type alias for adaptive step conditions
pub type AdaptiveStepCondition = Box<dyn Fn(usize) -> bool>;

/// Gradient accumulation mode
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AccumulationMode {
    /// Sum gradients (standard accumulation)
    Sum,
    /// Average gradients (normalize by number of accumulations)
    Average,
}

/// Gradient accumulator for micro-batch training
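///
/// # Example
///
/// A minimal usage sketch, assuming `f64` gradients stored in `Array1` and an
/// illustrative import path (adjust to your crate layout):
///
/// ```ignore
/// use optirs_core::gradient_accumulation::{AccumulationMode, GradientAccumulator};
/// use scirs2_core::ndarray::Array1;
///
/// let mut acc = GradientAccumulator::new(2, AccumulationMode::Average);
/// acc.accumulate(&[Array1::from_vec(vec![2.0, 4.0])])?;
/// acc.accumulate(&[Array1::from_vec(vec![4.0, 2.0])])?;
/// if acc.is_ready() {
///     // With `Average` mode this yields [3.0, 3.0].
///     let grads = acc.get_and_reset()?;
/// }
/// ```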
#[derive(Debug)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    /// Accumulated gradients
    accumulated_gradients: Vec<Array<A, D>>,
    /// Number of accumulation steps taken
    accumulation_count: usize,
    /// Target number of accumulations before update
    target_accumulations: usize,
    /// Accumulation mode
    mode: AccumulationMode,
    /// Whether accumulator has been initialized
    initialized: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
    /// Create a new gradient accumulator
    pub fn new(target_accumulations: usize, mode: AccumulationMode) -> Self {
        Self {
            accumulated_gradients: Vec::new(),
            accumulation_count: 0,
            target_accumulations,
            mode,
            initialized: false,
        }
    }

    /// Initialize accumulator with gradient shapes
    pub fn initialize(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        if self.initialized {
            return Err(OptimError::InvalidConfig(
                "Accumulator already initialized".to_string(),
            ));
        }

        self.accumulated_gradients = gradients
            .iter()
            .map(|g| Array::zeros(g.raw_dim()))
            .collect();

        self.initialized = true;
        Ok(())
    }

    /// Accumulate gradients from a micro-batch
    pub fn accumulate(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        if !self.initialized {
            self.initialize(gradients)?;
        }

        if gradients.len() != self.accumulated_gradients.len() {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected {} gradient arrays, got {}",
                self.accumulated_gradients.len(),
                gradients.len()
            )));
        }

        // Accumulate gradients
        for (acc_grad, micro_grad) in self.accumulated_gradients.iter_mut().zip(gradients.iter()) {
            if acc_grad.raw_dim() != micro_grad.raw_dim() {
                return Err(OptimError::DimensionMismatch(
                    "Gradient dimensions don't match".to_string(),
                ));
            }

            Zip::from(acc_grad).and(micro_grad).for_each(|acc, &micro| {
                *acc = *acc + micro;
            });
        }

        self.accumulation_count += 1;
        Ok(())
    }

    /// Check if accumulation is complete
    pub fn is_ready(&self) -> bool {
        self.accumulation_count >= self.target_accumulations
    }

    /// Get accumulated gradients and reset accumulator
    pub fn get_and_reset(&mut self) -> Result<Vec<Array<A, D>>> {
        if !self.is_ready() {
            return Err(OptimError::InvalidConfig(format!(
                "Accumulation not ready: {}/{} steps completed",
                self.accumulation_count, self.target_accumulations
            )));
        }

        let mut result = self.accumulated_gradients.clone();

        // Apply accumulation mode
        match self.mode {
            AccumulationMode::Sum => {
                // Gradients are already summed, nothing to do
            }
            AccumulationMode::Average => {
                let scale = A::one() / A::from(self.accumulation_count).unwrap();
                for grad in &mut result {
                    grad.mapv_inplace(|x| x * scale);
                }
            }
        }

        // Reset accumulator
        self.reset();

        Ok(result)
    }

    /// Reset accumulator state
    pub fn reset(&mut self) {
        for grad in &mut self.accumulated_gradients {
            grad.fill(A::zero());
        }
        self.accumulation_count = 0;
    }

    /// Get current accumulation count
    pub fn accumulation_count(&self) -> usize {
        self.accumulation_count
    }

    /// Get target accumulation count
    pub fn target_accumulations(&self) -> usize {
        self.target_accumulations
    }

    /// Set new target accumulation count
    pub fn set_target_accumulations(&mut self, target: usize) {
        self.target_accumulations = target;
    }

    /// Get accumulation mode
    pub fn mode(&self) -> AccumulationMode {
        self.mode
    }

    /// Set accumulation mode
    pub fn set_mode(&mut self, mode: AccumulationMode) {
        self.mode = mode;
    }

    /// Check if accumulator is initialized
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Get current progress as a fraction (0.0 to 1.0)
    pub fn progress(&self) -> f64 {
        if self.target_accumulations == 0 {
            1.0
        } else {
            self.accumulation_count as f64 / self.target_accumulations as f64
        }
    }
}

/// Variable accumulation scheduler
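///
/// # Example
///
/// A schedule-style sketch, assuming the same illustrative import path as above;
/// `grads` stands in for a slice of gradient arrays, and `step` is the optimizer
/// step count passed to each rule:
///
/// ```ignore
/// let mut var_acc = VariableAccumulator::new(2, AccumulationMode::Sum);
/// // After 5 optimizer steps, switch to 4 accumulations per step.
/// var_acc.add_adaptive_rule(|step| step > 5, 4);
///
/// var_acc.accumulate(&grads)?;
/// if var_acc.is_ready() {
///     let accumulated = var_acc.get_and_step()?;
/// }
/// ```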
pub struct VariableAccumulator<A: Float, D: Dimension> {
    /// Base accumulator
    accumulator: GradientAccumulator<A, D>,
    /// Variable accumulation steps based on conditions
    adaptive_steps: Vec<(AdaptiveStepCondition, usize)>,
    /// Current step count
    step_count: usize,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> VariableAccumulator<A, D> {
    /// Create a new variable accumulator
    pub fn new(initial_target: usize, mode: AccumulationMode) -> Self {
        Self {
            accumulator: GradientAccumulator::new(initial_target, mode),
            adaptive_steps: Vec::new(),
            step_count: 0,
        }
    }

    /// Add a condition-based accumulation rule
    pub fn add_adaptive_rule<F>(&mut self, condition: F, accumulation_steps: usize)
    where
        F: Fn(usize) -> bool + 'static,
    {
        self.adaptive_steps
            .push((Box::new(condition), accumulation_steps));
    }

    /// Update target accumulations based on current step
    fn update_target(&mut self) {
        for (condition, steps) in &self.adaptive_steps {
            if condition(self.step_count) {
                self.accumulator.set_target_accumulations(*steps);
                break;
            }
        }
    }

    /// Accumulate gradients with adaptive targeting
    pub fn accumulate(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        self.update_target();
        self.accumulator.accumulate(gradients)
    }

    /// Check if accumulation is ready
    pub fn is_ready(&self) -> bool {
        self.accumulator.is_ready()
    }

    /// Get accumulated gradients and advance step
    pub fn get_and_step(&mut self) -> Result<Vec<Array<A, D>>> {
        let result = self.accumulator.get_and_reset()?;
        self.step_count += 1;
        Ok(result)
    }

    /// Get current step count
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Get underlying accumulator
    pub fn accumulator(&self) -> &GradientAccumulator<A, D> {
        &self.accumulator
    }

    /// Get mutable reference to underlying accumulator
    pub fn accumulator_mut(&mut self) -> &mut GradientAccumulator<A, D> {
        &mut self.accumulator
    }
}

/// Micro-batch trainer that uses gradient accumulation
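///
/// # Example
///
/// A training-loop sketch, assuming the same illustrative import path as above;
/// `data` and `compute_gradients` are hypothetical stand-ins for your dataset
/// and backward pass:
///
/// ```ignore
/// // Micro-batch size 2, effective batch size 6 => 3 accumulation steps.
/// let mut trainer = MicroBatchTrainer::new(2, 6, AccumulationMode::Sum)?;
///
/// for micro_batch in data.chunks(trainer.micro_batch_size()) {
///     let grads = compute_gradients(micro_batch);
///     trainer.process_micro_batch(&grads)?;
///     if trainer.ready_for_step() {
///         let accumulated = trainer.get_accumulated_gradients()?;
///         // apply `accumulated` with your optimizer step here
///     }
/// }
/// ```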
#[derive(Debug)]
pub struct MicroBatchTrainer<A: Float, D: Dimension> {
    /// Gradient accumulator
    accumulator: GradientAccumulator<A, D>,
    /// Micro-batch size
    micro_batch_size: usize,
    /// Effective batch size (micro_batch_size * accumulation_steps)
    effective_batch_size: usize,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> MicroBatchTrainer<A, D> {
    /// Create a new micro-batch trainer
    pub fn new(
        micro_batch_size: usize,
        effective_batch_size: usize,
        mode: AccumulationMode,
    ) -> Result<Self> {
        if effective_batch_size < micro_batch_size {
            return Err(OptimError::InvalidConfig(
                "Effective batch size must be >= micro batch size".to_string(),
            ));
        }

        let accumulation_steps = effective_batch_size / micro_batch_size;
        let accumulator = GradientAccumulator::new(accumulation_steps, mode);

        Ok(Self {
            accumulator,
            micro_batch_size,
            effective_batch_size,
        })
    }

    /// Process a micro-batch and accumulate gradients
    pub fn process_micro_batch(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        self.accumulator.accumulate(gradients)
    }

    /// Check if ready for optimizer step
    pub fn ready_for_step(&self) -> bool {
        self.accumulator.is_ready()
    }

    /// Get accumulated gradients for optimizer step
    pub fn get_accumulated_gradients(&mut self) -> Result<Vec<Array<A, D>>> {
        self.accumulator.get_and_reset()
    }

    /// Get micro-batch size
    pub fn micro_batch_size(&self) -> usize {
        self.micro_batch_size
    }

    /// Get effective batch size
    pub fn effective_batch_size(&self) -> usize {
        self.effective_batch_size
    }

    /// Get accumulation progress
    pub fn progress(&self) -> f64 {
        self.accumulator.progress()
    }

    /// Set new effective batch size
    pub fn set_effective_batch_size(&mut self, effective_batch_size: usize) -> Result<()> {
        if effective_batch_size < self.micro_batch_size {
            return Err(OptimError::InvalidConfig(
                "Effective batch size must be >= micro batch size".to_string(),
            ));
        }

        self.effective_batch_size = effective_batch_size;
        let accumulation_steps = effective_batch_size / self.micro_batch_size;
        self.accumulator
            .set_target_accumulations(accumulation_steps);
        Ok(())
    }
}

/// Utility functions for gradient accumulation
pub mod utils {
    use super::*;

    /// Calculate optimal micro-batch size given memory constraints
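    ///
    /// # Example
    ///
    /// A rough sketch of the heuristic (values mirror the unit test below; the
    /// estimate assumes roughly three copies of the parameters per sample):
    ///
    /// ```ignore
    /// // 128 total samples, 100 MB budget, 1000 parameters of 8 bytes each.
    /// let micro = utils::calculate_micro_batch_size(128, 100, 1000, 8);
    /// assert!(micro >= 1 && 128 % micro == 0);
    /// ```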
    pub fn calculate_micro_batch_size(
        total_batch_size: usize,
        max_memory_mb: usize,
        param_count: usize,
        bytes_per_param: usize,
    ) -> usize {
        // Estimate memory usage per sample
        let memory_per_sample = param_count * bytes_per_param * 3; // params + grads + activations
        let max_samples = (max_memory_mb * 1_000_000) / memory_per_sample;

        // Choose a micro-batch size that divides the total batch size evenly
        let mut micro_batch_size = max_samples.min(total_batch_size);
        while !total_batch_size.is_multiple_of(micro_batch_size) && micro_batch_size > 1 {
            micro_batch_size -= 1;
        }

        micro_batch_size.max(1)
    }

    /// Calculate accumulation steps needed
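    ///
    /// # Example
    ///
    /// Ceiling division, so a partial final micro-batch still counts as a step
    /// (values mirror the unit tests below):
    ///
    /// ```ignore
    /// assert_eq!(utils::calculate_accumulation_steps(128, 32), 4);
    /// assert_eq!(utils::calculate_accumulation_steps(100, 32), 4);
    /// ```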
    pub fn calculate_accumulation_steps(
        total_batch_size: usize,
        micro_batch_size: usize,
    ) -> usize {
        total_batch_size.div_ceil(micro_batch_size) // Ceiling division
    }

    /// Validate gradient accumulation configuration
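    ///
    /// # Example
    ///
    /// A quick sanity check (values mirror the unit tests below):
    ///
    /// ```ignore
    /// utils::validate_config(32, 128, 4).unwrap();           // 32 * 4 == 128
    /// assert!(utils::validate_config(32, 100, 4).is_err());  // sizes don't match
    /// ```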
    pub fn validate_config(
        micro_batch_size: usize,
        effective_batch_size: usize,
        accumulation_steps: usize,
    ) -> Result<()> {
        if micro_batch_size == 0 {
            return Err(OptimError::InvalidConfig(
                "Micro batch size must be > 0".to_string(),
            ));
        }

        if effective_batch_size == 0 {
            return Err(OptimError::InvalidConfig(
                "Effective batch size must be > 0".to_string(),
            ));
        }

        if accumulation_steps == 0 {
            return Err(OptimError::InvalidConfig(
                "Accumulation steps must be > 0".to_string(),
            ));
        }

        if effective_batch_size != micro_batch_size * accumulation_steps {
            return Err(OptimError::InvalidConfig(format!(
                "Effective batch size ({}) != micro batch size ({}) * accumulation steps ({})",
                effective_batch_size, micro_batch_size, accumulation_steps
            )));
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_accumulator_sum() {
        let mut accumulator = GradientAccumulator::new(3, AccumulationMode::Sum);

        // First micro-batch
        let grad1 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        accumulator.accumulate(&grad1).unwrap();
        assert!(!accumulator.is_ready());

        // Second micro-batch
        let grad2 = vec![Array1::from_vec(vec![2.0, 3.0, 4.0])];
        accumulator.accumulate(&grad2).unwrap();
        assert!(!accumulator.is_ready());

        // Third micro-batch
        let grad3 = vec![Array1::from_vec(vec![1.0, 1.0, 1.0])];
        accumulator.accumulate(&grad3).unwrap();
        assert!(accumulator.is_ready());

        // Get accumulated gradients
        let result = accumulator.get_and_reset().unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].as_slice().unwrap(), &[4.0, 6.0, 8.0]); // Sum of all gradients

        // Should be reset
        assert!(!accumulator.is_ready());
        assert_eq!(accumulator.accumulation_count(), 0);
    }

    #[test]
    fn test_gradient_accumulator_average() {
        let mut accumulator = GradientAccumulator::new(2, AccumulationMode::Average);

        let grad1 = vec![Array1::from_vec(vec![2.0, 4.0])];
        let grad2 = vec![Array1::from_vec(vec![4.0, 2.0])];

        accumulator.accumulate(&grad1).unwrap();
        accumulator.accumulate(&grad2).unwrap();

        let result = accumulator.get_and_reset().unwrap();
        assert_eq!(result[0].as_slice().unwrap(), &[3.0, 3.0]); // Average of gradients
    }

    #[test]
    fn test_variable_accumulator() {
        let mut var_accumulator = VariableAccumulator::new(2, AccumulationMode::Sum);

        // Add rule: if step > 5, use 4 accumulation steps
        var_accumulator.add_adaptive_rule(|step| step > 5, 4);

        // First few steps should use 2 accumulations
        let grad = vec![Array1::from_vec(vec![1.0])];
        var_accumulator.accumulate(&grad).unwrap();
        var_accumulator.accumulate(&grad).unwrap();
        assert!(var_accumulator.is_ready());

        let _result = var_accumulator.get_and_step().unwrap();

        // Simulate more steps to trigger adaptive rule
        for _ in 0..6 {
            var_accumulator.accumulate(&grad).unwrap();
            var_accumulator.accumulate(&grad).unwrap();
            if var_accumulator.is_ready() {
                var_accumulator.get_and_step().unwrap();
            }
        }

        // Now should require 4 accumulations
        assert_eq!(var_accumulator.accumulator().target_accumulations(), 4);
    }

    #[test]
    fn test_micro_batch_trainer() {
        let mut trainer = MicroBatchTrainer::new(
            2, // micro batch size
            6, // effective batch size
            AccumulationMode::Sum,
        )
        .unwrap();

        assert_eq!(trainer.micro_batch_size(), 2);
        assert_eq!(trainer.effective_batch_size(), 6);

        let grad = vec![Array1::from_vec(vec![1.0, 1.0])];

        // Process 3 micro-batches (to reach effective batch size of 6)
        trainer.process_micro_batch(&grad).unwrap();
        assert!(!trainer.ready_for_step());

        trainer.process_micro_batch(&grad).unwrap();
        assert!(!trainer.ready_for_step());

        trainer.process_micro_batch(&grad).unwrap();
        assert!(trainer.ready_for_step());

        let result = trainer.get_accumulated_gradients().unwrap();
        assert_eq!(result[0].as_slice().unwrap(), &[3.0, 3.0]); // Sum of 3 micro-batches
    }

    #[test]
    fn test_calculate_micro_batch_size() {
        let micro_batch = utils::calculate_micro_batch_size(
            128,  // total batch size
            100,  // max memory MB
            1000, // param count
            8,    // bytes per param (f64)
        );

        // Should return a size that divides 128 evenly
        assert!(128 % micro_batch == 0);
        assert!(micro_batch > 0);
    }

    #[test]
    fn test_accumulation_steps_calculation() {
        assert_eq!(utils::calculate_accumulation_steps(128, 32), 4);
        assert_eq!(utils::calculate_accumulation_steps(100, 32), 4); // Ceiling division
        assert_eq!(utils::calculate_accumulation_steps(96, 32), 3);
    }

    #[test]
    fn test_config_validation() {
        // Valid config
        utils::validate_config(32, 128, 4).unwrap();

        // Invalid: micro batch size is 0
        assert!(utils::validate_config(0, 128, 4).is_err());

        // Invalid: sizes don't match
        assert!(utils::validate_config(32, 100, 4).is_err());
    }

    #[test]
    fn test_accumulator_progress() {
        let mut accumulator = GradientAccumulator::new(4, AccumulationMode::Sum);

        assert_relative_eq!(accumulator.progress(), 0.0);

        let grad = vec![Array1::from_vec(vec![1.0])];

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 0.25);

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 0.5);

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 0.75);

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 1.0);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let mut accumulator = GradientAccumulator::new(2, AccumulationMode::Sum);

        let grad1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        accumulator.accumulate(&grad1).unwrap();

        // Try to accumulate gradients with different dimensions
        let grad2 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        assert!(accumulator.accumulate(&grad2).is_err());

        // Try to accumulate different number of arrays
        let grad3 = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        assert!(accumulator.accumulate(&grad3).is_err());
    }

    #[test]
    fn test_get_before_ready_error() {
        let mut accumulator = GradientAccumulator::new(3, AccumulationMode::Sum);

        let grad = vec![Array1::from_vec(vec![1.0])];
        accumulator.accumulate(&grad).unwrap();

        // Try to get gradients before accumulation is complete
        assert!(accumulator.get_and_reset().is_err());
    }
}