optirs_core/gradient_accumulation/mod.rs
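
//! Gradient accumulation: sum or average gradients across several micro-batch
//! backward passes so that a single optimizer step sees an effective batch
//! larger than what fits in memory at once.
//!
//! A minimal usage sketch (illustrative, not a doc test; `micro_batches` and
//! `optimizer_step` are stand-ins for caller-provided pieces):
//!
//! ```ignore
//! let mut acc = GradientAccumulator::new(4, AccumulationMode::Average);
//! for micro_grads in micro_batches {
//!     acc.accumulate(&micro_grads)?;
//!     if acc.is_ready() {
//!         let grads = acc.get_and_reset()?;
//!         optimizer_step(&grads);
//!     }
//! }
//! ```
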
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

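/// Predicate over the optimizer step count, used to decide when an adaptive
/// accumulation rule applies.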
pub type AdaptiveStepCondition = Box<dyn Fn(usize) -> bool>;

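/// How accumulated gradients are combined when retrieved: `Sum` returns the
/// raw running total, `Average` divides it by the number of accumulated steps.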
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AccumulationMode {
    Sum,
    Average,
}

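/// Accumulates per-parameter gradient arrays over a fixed number of
/// micro-batches before they are released to the optimizer.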
#[derive(Debug)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    accumulated_gradients: Vec<Array<A, D>>,
    accumulation_count: usize,
    target_accumulations: usize,
    mode: AccumulationMode,
    initialized: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
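    /// Creates an accumulator that releases gradients after `target_accumulations` steps.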
    pub fn new(target_accumulations: usize, mode: AccumulationMode) -> Self {
        Self {
            accumulated_gradients: Vec::new(),
            accumulation_count: 0,
            target_accumulations,
            mode,
            initialized: false,
        }
    }

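    /// Allocates zeroed accumulation buffers matching the shapes of `gradients`.
    /// Errors if called more than once.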
    pub fn initialize(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        if self.initialized {
            return Err(OptimError::InvalidConfig(
                "Accumulator already initialized".to_string(),
            ));
        }

        self.accumulated_gradients = gradients
            .iter()
            .map(|g| Array::zeros(g.raw_dim()))
            .collect();

        self.initialized = true;
        Ok(())
    }

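    /// Adds one micro-batch of gradients to the running totals, lazily
    /// initializing the buffers on first use.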
    pub fn accumulate(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        if !self.initialized {
            self.initialize(gradients)?;
        }

        if gradients.len() != self.accumulated_gradients.len() {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected {} gradient arrays, got {}",
                self.accumulated_gradients.len(),
                gradients.len()
            )));
        }

        for (acc_grad, micro_grad) in self.accumulated_gradients.iter_mut().zip(gradients.iter()) {
            if acc_grad.raw_dim() != micro_grad.raw_dim() {
                return Err(OptimError::DimensionMismatch(
                    "Gradient dimensions don't match".to_string(),
                ));
            }

            Zip::from(acc_grad).and(micro_grad).for_each(|acc, &micro| {
                *acc = *acc + micro;
            });
        }

        self.accumulation_count += 1;
        Ok(())
    }

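    /// Returns `true` once the target number of micro-batches has been accumulated.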
    pub fn is_ready(&self) -> bool {
        self.accumulation_count >= self.target_accumulations
    }

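    /// Returns the combined gradients (summed or averaged per `mode`) and
    /// resets the accumulator. Errors if the target count has not been reached.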
    pub fn get_and_reset(&mut self) -> Result<Vec<Array<A, D>>> {
        if !self.is_ready() {
            return Err(OptimError::InvalidConfig(format!(
                "Accumulation not ready: {}/{} steps completed",
                self.accumulation_count, self.target_accumulations
            )));
        }

        let mut result = self.accumulated_gradients.clone();

        match self.mode {
            AccumulationMode::Sum => {
                // The raw sum is returned as-is.
            }
            AccumulationMode::Average => {
                let scale = A::one() / A::from(self.accumulation_count).unwrap();
                for grad in &mut result {
                    grad.mapv_inplace(|x| x * scale);
                }
            }
        }

        self.reset();

        Ok(result)
    }

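    /// Zeroes the accumulation buffers and resets the step counter.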
    pub fn reset(&mut self) {
        for grad in &mut self.accumulated_gradients {
            grad.fill(A::zero());
        }
        self.accumulation_count = 0;
    }

    pub fn accumulation_count(&self) -> usize {
        self.accumulation_count
    }

    pub fn target_accumulations(&self) -> usize {
        self.target_accumulations
    }

    pub fn set_target_accumulations(&mut self, target: usize) {
        self.target_accumulations = target;
    }

    pub fn mode(&self) -> AccumulationMode {
        self.mode
    }

    pub fn set_mode(&mut self, mode: AccumulationMode) {
        self.mode = mode;
    }

    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

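    /// Fraction of the accumulation target completed; reports 1.0 when the
    /// target is zero.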
    pub fn progress(&self) -> f64 {
        if self.target_accumulations == 0 {
            1.0
        } else {
            self.accumulation_count as f64 / self.target_accumulations as f64
        }
    }
}

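/// A [`GradientAccumulator`] whose accumulation target can change over
/// training, driven by user-supplied rules on the optimizer step count.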
pub struct VariableAccumulator<A: Float, D: Dimension> {
    accumulator: GradientAccumulator<A, D>,
    adaptive_steps: Vec<(AdaptiveStepCondition, usize)>,
    step_count: usize,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> VariableAccumulator<A, D> {
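    /// Creates a variable accumulator that starts with `initial_target` accumulation steps.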
    pub fn new(initial_target: usize, mode: AccumulationMode) -> Self {
        Self {
            accumulator: GradientAccumulator::new(initial_target, mode),
            adaptive_steps: Vec::new(),
            step_count: 0,
        }
    }

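    /// Registers a rule: when `condition(step_count)` holds, the accumulation
    /// target switches to `accumulation_steps`. Rules are checked in insertion
    /// order and the first match wins.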
    pub fn add_adaptive_rule<F>(&mut self, condition: F, accumulation_steps: usize)
    where
        F: Fn(usize) -> bool + 'static,
    {
        self.adaptive_steps
            .push((Box::new(condition), accumulation_steps));
    }

    fn update_target(&mut self) {
        for (condition, steps) in &self.adaptive_steps {
            if condition(self.step_count) {
                self.accumulator.set_target_accumulations(*steps);
                break;
            }
        }
    }

    pub fn accumulate(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        self.update_target();
        self.accumulator.accumulate(gradients)
    }

    pub fn is_ready(&self) -> bool {
        self.accumulator.is_ready()
    }

    pub fn get_and_step(&mut self) -> Result<Vec<Array<A, D>>> {
        let result = self.accumulator.get_and_reset()?;
        self.step_count += 1;
        Ok(result)
    }

    pub fn step_count(&self) -> usize {
        self.step_count
    }

    pub fn accumulator(&self) -> &GradientAccumulator<A, D> {
        &self.accumulator
    }

    pub fn accumulator_mut(&mut self) -> &mut GradientAccumulator<A, D> {
        &mut self.accumulator
    }
}

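/// Drives micro-batch training: accumulates gradients from micro-batches of
/// size `micro_batch_size` until they cover `effective_batch_size` samples.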
#[derive(Debug)]
pub struct MicroBatchTrainer<A: Float, D: Dimension> {
    accumulator: GradientAccumulator<A, D>,
    micro_batch_size: usize,
    effective_batch_size: usize,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> MicroBatchTrainer<A, D> {
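    /// Creates a trainer; the accumulation target is
    /// `effective_batch_size / micro_batch_size` (integer division).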
    pub fn new(
        micro_batch_size: usize,
        effective_batch_size: usize,
        mode: AccumulationMode,
    ) -> Result<Self> {
        if effective_batch_size < micro_batch_size {
            return Err(OptimError::InvalidConfig(
                "Effective batch size must be >= micro batch size".to_string(),
            ));
        }

        let accumulation_steps = effective_batch_size / micro_batch_size;
        let accumulator = GradientAccumulator::new(accumulation_steps, mode);

        Ok(Self {
            accumulator,
            micro_batch_size,
            effective_batch_size,
        })
    }

    pub fn process_micro_batch(&mut self, gradients: &[Array<A, D>]) -> Result<()> {
        self.accumulator.accumulate(gradients)
    }

    pub fn ready_for_step(&self) -> bool {
        self.accumulator.is_ready()
    }

    pub fn get_accumulated_gradients(&mut self) -> Result<Vec<Array<A, D>>> {
        self.accumulator.get_and_reset()
    }

    pub fn micro_batch_size(&self) -> usize {
        self.micro_batch_size
    }

    pub fn effective_batch_size(&self) -> usize {
        self.effective_batch_size
    }

    pub fn progress(&self) -> f64 {
        self.accumulator.progress()
    }

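    /// Changes the effective batch size and recomputes the accumulation target.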
    pub fn set_effective_batch_size(&mut self, effective_batch_size: usize) -> Result<()> {
        if effective_batch_size < self.micro_batch_size {
            return Err(OptimError::InvalidConfig(
                "Effective batch size must be >= micro batch size".to_string(),
            ));
        }

        self.effective_batch_size = effective_batch_size;
        let accumulation_steps = effective_batch_size / self.micro_batch_size;
        self.accumulator
            .set_target_accumulations(accumulation_steps);
        Ok(())
    }
}

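/// Helper routines for sizing micro-batches and validating configurations.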
pub mod utils {
    use super::*;

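    /// Picks the largest micro-batch size that fits the memory budget and
    /// evenly divides `total_batch_size`, bottoming out at 1.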
    pub fn calculate_micro_batch_size(
        total_batch_size: usize,
        max_memory_mb: usize,
        param_count: usize,
        bytes_per_param: usize,
    ) -> usize {
        // Assumes roughly three copies per parameter (e.g. weights, gradients,
        // optimizer state).
        let memory_per_sample = param_count * bytes_per_param * 3;
        let max_samples = (max_memory_mb * 1_000_000) / memory_per_sample;

        let mut micro_batch_size = max_samples.min(total_batch_size);
        while !total_batch_size.is_multiple_of(micro_batch_size) && micro_batch_size > 1 {
            micro_batch_size -= 1;
        }

        micro_batch_size.max(1)
    }

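    /// Number of micro-batches needed to cover `total_batch_size`, rounding up.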
    pub fn calculate_accumulation_steps(
        total_batch_size: usize,
        micro_batch_size: usize,
    ) -> usize {
        total_batch_size.div_ceil(micro_batch_size)
    }

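    /// Checks that all sizes are positive and mutually consistent, i.e.
    /// `effective_batch_size == micro_batch_size * accumulation_steps`.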
    pub fn validate_config(
        micro_batch_size: usize,
        effective_batch_size: usize,
        accumulation_steps: usize,
    ) -> Result<()> {
        if micro_batch_size == 0 {
            return Err(OptimError::InvalidConfig(
                "Micro batch size must be > 0".to_string(),
            ));
        }

        if effective_batch_size == 0 {
            return Err(OptimError::InvalidConfig(
                "Effective batch size must be > 0".to_string(),
            ));
        }

        if accumulation_steps == 0 {
            return Err(OptimError::InvalidConfig(
                "Accumulation steps must be > 0".to_string(),
            ));
        }

        if effective_batch_size != micro_batch_size * accumulation_steps {
            return Err(OptimError::InvalidConfig(format!(
                "Effective batch size ({}) != micro batch size ({}) * accumulation steps ({})",
                effective_batch_size, micro_batch_size, accumulation_steps
            )));
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_accumulator_sum() {
        let mut accumulator = GradientAccumulator::new(3, AccumulationMode::Sum);

        let grad1 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        accumulator.accumulate(&grad1).unwrap();
        assert!(!accumulator.is_ready());

        let grad2 = vec![Array1::from_vec(vec![2.0, 3.0, 4.0])];
        accumulator.accumulate(&grad2).unwrap();
        assert!(!accumulator.is_ready());

        let grad3 = vec![Array1::from_vec(vec![1.0, 1.0, 1.0])];
        accumulator.accumulate(&grad3).unwrap();
        assert!(accumulator.is_ready());

        let result = accumulator.get_and_reset().unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].as_slice().unwrap(), &[4.0, 6.0, 8.0]);

        assert!(!accumulator.is_ready());
        assert_eq!(accumulator.accumulation_count(), 0);
    }

    #[test]
    fn test_gradient_accumulator_average() {
        let mut accumulator = GradientAccumulator::new(2, AccumulationMode::Average);

        let grad1 = vec![Array1::from_vec(vec![2.0, 4.0])];
        let grad2 = vec![Array1::from_vec(vec![4.0, 2.0])];

        accumulator.accumulate(&grad1).unwrap();
        accumulator.accumulate(&grad2).unwrap();

        let result = accumulator.get_and_reset().unwrap();
        assert_eq!(result[0].as_slice().unwrap(), &[3.0, 3.0]);
    }

    #[test]
    fn test_variable_accumulator() {
        let mut var_accumulator = VariableAccumulator::new(2, AccumulationMode::Sum);

        var_accumulator.add_adaptive_rule(|step| step > 5, 4);

        let grad = vec![Array1::from_vec(vec![1.0])];
        var_accumulator.accumulate(&grad).unwrap();
        var_accumulator.accumulate(&grad).unwrap();
        assert!(var_accumulator.is_ready());

        let _result = var_accumulator.get_and_step().unwrap();

        for _ in 0..6 {
            var_accumulator.accumulate(&grad).unwrap();
            var_accumulator.accumulate(&grad).unwrap();
            if var_accumulator.is_ready() {
                var_accumulator.get_and_step().unwrap();
            }
        }

        assert_eq!(var_accumulator.accumulator().target_accumulations(), 4);
    }

    #[test]
    fn test_micro_batch_trainer() {
        let mut trainer = MicroBatchTrainer::new(2, 6, AccumulationMode::Sum).unwrap();

        assert_eq!(trainer.micro_batch_size(), 2);
        assert_eq!(trainer.effective_batch_size(), 6);

        let grad = vec![Array1::from_vec(vec![1.0, 1.0])];

        trainer.process_micro_batch(&grad).unwrap();
        assert!(!trainer.ready_for_step());

        trainer.process_micro_batch(&grad).unwrap();
        assert!(!trainer.ready_for_step());

        trainer.process_micro_batch(&grad).unwrap();
        assert!(trainer.ready_for_step());

        let result = trainer.get_accumulated_gradients().unwrap();
        assert_eq!(result[0].as_slice().unwrap(), &[3.0, 3.0]);
    }

    #[test]
    fn test_calculate_micro_batch_size() {
        let micro_batch = utils::calculate_micro_batch_size(
            128,  // total_batch_size
            100,  // max_memory_mb
            1000, // param_count
            8,    // bytes_per_param
        );

        assert!(128 % micro_batch == 0);
        assert!(micro_batch > 0);
    }

    #[test]
    fn test_accumulation_steps_calculation() {
        assert_eq!(utils::calculate_accumulation_steps(128, 32), 4);
        assert_eq!(utils::calculate_accumulation_steps(100, 32), 4);
        assert_eq!(utils::calculate_accumulation_steps(96, 32), 3);
    }

    #[test]
    fn test_config_validation() {
        utils::validate_config(32, 128, 4).unwrap();

        assert!(utils::validate_config(0, 128, 4).is_err());

        assert!(utils::validate_config(32, 100, 4).is_err());
    }

    #[test]
    fn test_accumulator_progress() {
        let mut accumulator = GradientAccumulator::new(4, AccumulationMode::Sum);

        assert_relative_eq!(accumulator.progress(), 0.0);

        let grad = vec![Array1::from_vec(vec![1.0])];

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 0.25);

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 0.5);

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 0.75);

        accumulator.accumulate(&grad).unwrap();
        assert_relative_eq!(accumulator.progress(), 1.0);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let mut accumulator = GradientAccumulator::new(2, AccumulationMode::Sum);

        let grad1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        accumulator.accumulate(&grad1).unwrap();

        let grad2 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        assert!(accumulator.accumulate(&grad2).is_err());

        let grad3 = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        assert!(accumulator.accumulate(&grad3).is_err());
    }

    #[test]
    fn test_get_before_ready_error() {
        let mut accumulator = GradientAccumulator::new(3, AccumulationMode::Sum);

        let grad = vec![Array1::from_vec(vec![1.0])];
        accumulator.accumulate(&grad).unwrap();

        assert!(accumulator.get_and_reset().is_err());
    }
}