optirs_core/memory_efficient_optimizer.rs

//! Memory-efficient optimizer operations
//!
//! This module provides memory-efficient optimization for very large models
//! through gradient accumulation, chunked processing, and memory usage estimation.
//!
//! # Features
//!
//! - Gradient accumulation to reduce memory pressure
//! - Chunked parameter processing for large models
//! - Memory usage estimation and recommendations
//! - Streaming gradient computation
//!
//! # Performance
//!
//! Enables optimization of models with billions of parameters through efficient memory management.

use scirs2_core::ndarray::{s, Array1, ArrayView1, Ix1, ScalarOperand};
use scirs2_core::numeric::{Float, Zero};
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Gradient accumulator for memory-efficient training
///
/// Accumulates gradients over multiple micro-batches before applying updates,
/// reducing memory requirements for large batch training.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::memory_efficient_optimizer::GradientAccumulator;
///
/// let mut accumulator = GradientAccumulator::<f32>::new(1000);
///
/// // Accumulate gradients from 4 micro-batches
/// for _ in 0..4 {
///     let micro_batch_grads = Array1::from_elem(1000, 0.1);
///     accumulator.accumulate(&micro_batch_grads.view()).unwrap();
/// }
///
/// // Get averaged gradients
/// let avg_grads = accumulator.average().unwrap();
/// ```
pub struct GradientAccumulator<A: Float> {
    accumulated: Array1<A>,
    count: usize,
}

impl<A: Float + ScalarOperand + Debug + Zero> GradientAccumulator<A> {
    /// Creates a new gradient accumulator
    ///
    /// # Arguments
    ///
    /// * `size` - Size of gradient vectors
    pub fn new(size: usize) -> Self {
        Self {
            accumulated: Array1::zeros(size),
            count: 0,
        }
    }

    /// Accumulates a gradient vector
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients to accumulate
    pub fn accumulate(&mut self, gradients: &ArrayView1<A>) -> Result<()> {
        if gradients.len() != self.accumulated.len() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Gradient size ({}) doesn't match accumulator size ({})",
                gradients.len(),
                self.accumulated.len()
            )));
        }

        self.accumulated = &self.accumulated + gradients;
        self.count += 1;

        Ok(())
    }

    /// Returns the number of accumulated gradients
    pub fn count(&self) -> usize {
        self.count
    }

    /// Computes the average of the accumulated gradients
    ///
    /// Returns the averaged gradients and resets the accumulator.
    pub fn average(&mut self) -> Result<Array1<A>> {
        if self.count == 0 {
            return Err(crate::error::OptimError::InvalidConfig(
                "No gradients accumulated".to_string(),
            ));
        }

        let scale = A::from(self.count).unwrap();
        let averaged = &self.accumulated / scale;

        // Reset accumulator
        self.reset();

        Ok(averaged)
    }

    /// Resets the accumulator
    pub fn reset(&mut self) {
        self.accumulated.fill(A::zero());
        self.count = 0;
    }

    /// Checks whether the accumulator has reached the target count
    ///
    /// # Arguments
    ///
    /// * `target` - Target number of accumulations
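    ///
    /// # Examples
    ///
    /// A sketch of the usual accumulation loop; `compute_micro_batch_grads`,
    /// `optimizer`, and `params` are hypothetical placeholders, not part of this
    /// crate:
    ///
    /// ```ignore
    /// let accumulation_steps = 4;
    /// let mut accumulator = GradientAccumulator::<f32>::new(1000);
    ///
    /// loop {
    ///     // Hypothetical helper producing the gradients of one micro-batch
    ///     let grads = compute_micro_batch_grads();
    ///     accumulator.accumulate(&grads.view())?;
    ///
    ///     if accumulator.is_ready(accumulation_steps) {
    ///         // `average` also resets the accumulator for the next cycle
    ///         let avg = accumulator.average()?;
    ///         params = optimizer.step(&params, &avg)?;
    ///     }
    /// }
    /// ```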
    pub fn is_ready(&self, target: usize) -> bool {
        self.count >= target
    }
}

/// Chunked optimizer for processing large parameter arrays in chunks
///
/// Enables optimization of very large models by processing parameters
/// in manageable chunks, reducing peak memory usage.
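///
/// # Examples
///
/// A minimal usage sketch, mirroring the unit tests below; it assumes the `SGD`
/// optimizer is exported at `optirs_core::optimizers::SGD`:
///
/// ```no_run
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::memory_efficient_optimizer::ChunkedOptimizer;
/// use optirs_core::optimizers::SGD;
///
/// // Wrap a plain SGD optimizer and process parameters 1,000 elements at a time
/// let mut chunked = ChunkedOptimizer::new(SGD::new(0.01), Some(1_000));
///
/// let params = Array1::from_elem(10_000, 1.0f32);
/// let gradients = Array1::from_elem(10_000, 0.1f32);
/// let updated = chunked.step_chunked(&params, &gradients).unwrap();
///
/// assert_eq!(updated.len(), 10_000);
/// assert_eq!(chunked.num_chunks(10_000), 10);
/// ```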
pub struct ChunkedOptimizer<O, A>
where
    O: Optimizer<A, Ix1> + Clone,
    A: Float + ScalarOperand + Debug,
{
    base_optimizer: O,
    chunk_size: usize,
    _phantom: std::marker::PhantomData<A>,
}

impl<O, A> ChunkedOptimizer<O, A>
where
    O: Optimizer<A, Ix1> + Clone,
    A: Float + ScalarOperand + Debug,
{
    /// Creates a new chunked optimizer
    ///
    /// # Arguments
    ///
    /// * `base_optimizer` - Base optimizer to use for each chunk
    /// * `chunk_size` - Size of each chunk (default: 1M elements)
    pub fn new(base_optimizer: O, chunk_size: Option<usize>) -> Self {
        let chunk_size = chunk_size.unwrap_or(1_000_000);

        Self {
            base_optimizer,
            chunk_size,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Processes parameters in chunks
    ///
    /// # Arguments
    ///
    /// * `params` - Full parameter array
    /// * `gradients` - Full gradient array
    ///
    /// # Returns
    ///
    /// Updated parameters
    pub fn step_chunked(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>> {
        if params.len() != gradients.len() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Parameters ({}) and gradients ({}) must have the same size",
                params.len(),
                gradients.len()
            )));
        }

        let total_size = params.len();
        let mut updated = Array1::zeros(total_size);

        // Process in chunks
        let num_chunks = total_size.div_ceil(self.chunk_size);

        for chunk_idx in 0..num_chunks {
            let start = chunk_idx * self.chunk_size;
            let end = (start + self.chunk_size).min(total_size);

            // Copy out this chunk of parameters and gradients
            let params_chunk = params.slice(s![start..end]).to_owned();
            let grads_chunk = gradients.slice(s![start..end]).to_owned();

            // Update chunk
            let updated_chunk = self.base_optimizer.step(&params_chunk, &grads_chunk)?;

            // Copy back to result
            updated.slice_mut(s![start..end]).assign(&updated_chunk);
        }

        Ok(updated)
    }

    /// Returns the chunk size
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Calculates the number of chunks for a given total size
    pub fn num_chunks(&self, total_size: usize) -> usize {
        total_size.div_ceil(self.chunk_size)
    }
}

/// Memory usage estimator for optimizers
///
/// Provides utilities for estimating memory requirements and recommending
/// optimal configurations for different optimizer types.
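///
/// # Examples
///
/// A small sketch comparing estimates for a 1M-parameter model in `f32`
/// (the figures mirror the unit tests below):
///
/// ```
/// use optirs_core::memory_efficient_optimizer::MemoryUsageEstimator;
///
/// // Parameters + gradients only
/// assert_eq!(MemoryUsageEstimator::sgd(1_000_000, 4), 8_000_000);
/// // Parameters + gradients + first and second moments
/// assert_eq!(MemoryUsageEstimator::adam(1_000_000, 4), 16_000_000);
/// ```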
pub struct MemoryUsageEstimator;

impl MemoryUsageEstimator {
    /// Estimates memory usage for SGD without momentum
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    ///
    /// # Returns
    ///
    /// Estimated memory usage in bytes
    pub fn sgd(num_params: usize, dtype_size: usize) -> usize {
        // Parameters + gradients
        num_params * dtype_size * 2
    }

    /// Estimates memory usage for SGD with momentum
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    ///
    /// # Returns
    ///
    /// Estimated memory usage in bytes
    pub fn sgd_with_momentum(num_params: usize, dtype_size: usize) -> usize {
        // Parameters + gradients + velocity
        num_params * dtype_size * 3
    }

    /// Estimates memory usage for the Adam optimizer
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    ///
    /// # Returns
    ///
    /// Estimated memory usage in bytes
    pub fn adam(num_params: usize, dtype_size: usize) -> usize {
        // Parameters + gradients + first moment + second moment
        num_params * dtype_size * 4
    }

    /// Recommends a chunk size based on available memory
    ///
    /// # Arguments
    ///
    /// * `total_params` - Total number of parameters
    /// * `available_memory_bytes` - Available memory in bytes
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    /// * `optimizer_state_multiplier` - Memory multiplier for optimizer state
    ///
    /// # Returns
    ///
    /// Recommended chunk size
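    ///
    /// # Examples
    ///
    /// A sketch for 100M parameters, 1 GB of available memory, `f32` values, and
    /// Adam's 4x state multiplier (mirroring the unit test below):
    ///
    /// ```
    /// use optirs_core::memory_efficient_optimizer::MemoryUsageEstimator;
    ///
    /// let chunk = MemoryUsageEstimator::recommend_chunk_size(
    ///     100_000_000,   // total parameters
    ///     1_000_000_000, // 1 GB available
    ///     4,             // f32
    ///     4,             // Adam state multiplier
    /// );
    /// // 80% of the 62.5M parameters that fit in 1 GB
    /// assert_eq!(chunk, 50_000_000);
    /// ```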
    pub fn recommend_chunk_size(
        total_params: usize,
        available_memory_bytes: usize,
        dtype_size: usize,
        optimizer_state_multiplier: usize,
    ) -> usize {
        let memory_per_param = dtype_size * optimizer_state_multiplier;
        let max_params = available_memory_bytes / memory_per_param;

        // Use 80% of available memory to leave headroom
        let safe_params = (max_params * 80) / 100;

        // Clamp to the total parameter count, with a floor of 1024 elements
        safe_params.min(total_params).max(1024)
    }

    /// Returns the recommended number of accumulation steps for a given batch size
    ///
    /// # Arguments
    ///
    /// * `target_batch_size` - Desired effective batch size
    /// * `max_micro_batch_size` - Maximum micro-batch that fits in memory
    ///
    /// # Returns
    ///
    /// Number of gradient accumulation steps
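    ///
    /// # Examples
    ///
    /// For instance, an effective batch of 100 built from micro-batches of at
    /// most 32 needs 4 accumulation steps, since the division rounds up:
    ///
    /// ```
    /// use optirs_core::memory_efficient_optimizer::MemoryUsageEstimator;
    ///
    /// assert_eq!(MemoryUsageEstimator::recommend_accumulation_steps(128, 32), 4);
    /// assert_eq!(MemoryUsageEstimator::recommend_accumulation_steps(100, 32), 4);
    /// ```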
    pub fn recommend_accumulation_steps(
        target_batch_size: usize,
        max_micro_batch_size: usize,
    ) -> usize {
        target_batch_size.div_ceil(max_micro_batch_size)
    }

    /// Estimates peak memory usage during training
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `batch_size` - Batch size
    /// * `sequence_length` - Sequence length (for transformers, 1 otherwise)
    /// * `dtype_size` - Size of data type in bytes
    /// * `optimizer_type` - Type of optimizer ("sgd", "adam", etc.)
    ///
    /// # Returns
    ///
    /// Estimated peak memory in bytes
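    ///
    /// # Examples
    ///
    /// A rough estimate for a 10M-parameter model trained with Adam, batch size
    /// 32, and sequence length 512 (mirroring the unit test below):
    ///
    /// ```
    /// use optirs_core::memory_efficient_optimizer::MemoryUsageEstimator;
    ///
    /// let peak = MemoryUsageEstimator::estimate_peak_memory(10_000_000, 32, 512, 4, "adam");
    /// assert!(peak > 100_000_000); // model + gradients + optimizer state + activations
    /// ```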
    pub fn estimate_peak_memory(
        num_params: usize,
        batch_size: usize,
        sequence_length: usize,
        dtype_size: usize,
        optimizer_type: &str,
    ) -> usize {
        // Model parameters
        let param_memory = num_params * dtype_size;

        // Gradients
        let grad_memory = num_params * dtype_size;

        // Optimizer state
        let optimizer_memory = match optimizer_type {
            "sgd" => num_params * dtype_size,
            "adam" | "adamw" => num_params * dtype_size * 2,
            _ => num_params * dtype_size,
        };

        // Activations (rough estimate: batch_size * sequence_length * hidden_dim,
        // with hidden_dim approximated as sqrt(num_params))
        let hidden_dim = (num_params as f64).sqrt() as usize;
        let activation_memory = batch_size * sequence_length * hidden_dim * dtype_size;

        param_memory + grad_memory + optimizer_memory + activation_memory
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;
    use approx::assert_relative_eq;

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::<f32>::new(100);

        // Accumulate some gradients
        let grad1 = Array1::from_elem(100, 1.0);
        let grad2 = Array1::from_elem(100, 2.0);

        accumulator.accumulate(&grad1.view()).unwrap();
        accumulator.accumulate(&grad2.view()).unwrap();

        assert_eq!(accumulator.count(), 2);
        assert!(accumulator.is_ready(2));

        // Get average
        let avg = accumulator.average().unwrap();
        assert_relative_eq!(avg[0], 1.5, epsilon = 1e-6);

        // After average, accumulator should be reset
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_chunked_optimizer() {
        let optimizer = SGD::new(0.01);
        let mut chunked_opt = ChunkedOptimizer::new(optimizer, Some(10));

        let params = Array1::from_vec((0..25).map(|i| i as f32).collect());
        let gradients = Array1::from_elem(25, 0.1);

        let updated = chunked_opt.step_chunked(&params, &gradients).unwrap();

        // Verify updates
        assert_eq!(updated.len(), 25);
        assert_relative_eq!(updated[0], 0.0 - 0.01 * 0.1, epsilon = 1e-6);

        // Check number of chunks
        assert_eq!(chunked_opt.num_chunks(25), 3);
    }

    #[test]
    fn test_memory_estimator_sgd() {
        // SGD for 1M parameters (f32)
        let mem = MemoryUsageEstimator::sgd(1_000_000, 4);
        assert_eq!(mem, 8_000_000); // 8 MB

        // SGD with momentum
        let mem = MemoryUsageEstimator::sgd_with_momentum(1_000_000, 4);
        assert_eq!(mem, 12_000_000); // 12 MB
    }

    #[test]
    fn test_memory_estimator_adam() {
        // Adam for 1M parameters (f32)
        let mem = MemoryUsageEstimator::adam(1_000_000, 4);
        assert_eq!(mem, 16_000_000); // 16 MB
    }

    #[test]
    fn test_recommend_chunk_size() {
        // 1GB available, f32, Adam optimizer
        let chunk_size = MemoryUsageEstimator::recommend_chunk_size(
            100_000_000,   // 100M total params
            1_000_000_000, // 1GB available
            4,             // f32
            4,             // Adam state multiplier
        );

        // Should be around 50M params (80% of 62.5M that fits in 1GB)
        assert!(chunk_size > 40_000_000);
        assert!(chunk_size < 60_000_000);
    }

    #[test]
    fn test_recommend_accumulation_steps() {
        let steps = MemoryUsageEstimator::recommend_accumulation_steps(128, 32);
        assert_eq!(steps, 4);

        let steps = MemoryUsageEstimator::recommend_accumulation_steps(100, 32);
        assert_eq!(steps, 4); // Rounds up
    }

    #[test]
    fn test_estimate_peak_memory() {
        let peak = MemoryUsageEstimator::estimate_peak_memory(
            10_000_000, // 10M params
            32,         // batch size
            512,        // sequence length
            4,          // f32
            "adam",
        );

        // Should be substantial (model + optimizer + activations)
        assert!(peak > 100_000_000); // > 100MB
    }
}