optirs_core/memory_efficient_optimizer.rs
//! Memory-efficient optimizer operations
//!
//! This module provides memory-efficient optimization for very large models
//! through gradient accumulation, chunked processing, and memory usage estimation.
//!
//! # Features
//!
//! - Gradient accumulation to reduce memory pressure
//! - Chunked parameter processing for large models
//! - Memory usage estimation and recommendations
//! - Streaming gradient computation
//!
//! # Performance
//!
//! Enables optimization of models with billions of parameters through efficient memory management.
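//!
//! # Examples
//!
//! A minimal sketch of the accumulate-then-update workflow (illustrative sizes; see
//! [`ChunkedOptimizer`] for chunked parameter updates):
//!
//! ```
//! use scirs2_core::ndarray::Array1;
//! use optirs_core::memory_efficient_optimizer::{GradientAccumulator, MemoryUsageEstimator};
//!
//! // How many micro-batches to accumulate for an effective batch size of 128
//! // when only 32 samples fit in memory at once.
//! let steps = MemoryUsageEstimator::recommend_accumulation_steps(128, 32);
//!
//! let mut accumulator = GradientAccumulator::<f32>::new(1_000);
//! for _ in 0..steps {
//!     let micro_batch_grads = Array1::from_elem(1_000, 0.1);
//!     accumulator.accumulate(&micro_batch_grads.view()).unwrap();
//! }
//!
//! if accumulator.is_ready(steps) {
//!     // The averaged gradients are then handed to an optimizer step.
//!     let grads = accumulator.average().unwrap();
//!     assert_eq!(grads.len(), 1_000);
//! }
//! ```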

use scirs2_core::ndarray::{s, Array1, ArrayView1, Ix1, ScalarOperand};
use scirs2_core::numeric::{Float, Zero};
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Gradient accumulator for memory-efficient training
///
/// Accumulates gradients over multiple micro-batches before applying updates,
/// reducing memory requirements for large batch training.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::memory_efficient_optimizer::GradientAccumulator;
///
/// let mut accumulator = GradientAccumulator::<f32>::new(1000);
///
/// // Accumulate gradients from 4 micro-batches
/// for _ in 0..4 {
///     let micro_batch_grads = Array1::from_elem(1000, 0.1);
///     accumulator.accumulate(&micro_batch_grads.view()).unwrap();
/// }
///
/// // Get averaged gradients
/// let avg_grads = accumulator.average().unwrap();
/// ```
pub struct GradientAccumulator<A: Float> {
    accumulated: Array1<A>,
    count: usize,
}

impl<A: Float + ScalarOperand + Debug + Zero> GradientAccumulator<A> {
    /// Creates a new gradient accumulator
    ///
    /// # Arguments
    ///
    /// * `size` - Size of gradient vectors
    pub fn new(size: usize) -> Self {
        Self {
            accumulated: Array1::zeros(size),
            count: 0,
        }
    }

    /// Accumulate a gradient vector
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients to accumulate
    pub fn accumulate(&mut self, gradients: &ArrayView1<A>) -> Result<()> {
        if gradients.len() != self.accumulated.len() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Gradient size ({}) doesn't match accumulator size ({})",
                gradients.len(),
                self.accumulated.len()
            )));
        }

        self.accumulated = &self.accumulated + gradients;
        self.count += 1;

        Ok(())
    }

    /// Get the number of accumulated gradients
    pub fn count(&self) -> usize {
        self.count
    }

    /// Compute the average of accumulated gradients
    ///
    /// Returns the averaged gradients and resets the accumulator.
    pub fn average(&mut self) -> Result<Array1<A>> {
        if self.count == 0 {
            return Err(crate::error::OptimError::InvalidConfig(
                "No gradients accumulated".to_string(),
            ));
        }

        let scale = A::from(self.count).unwrap();
        let averaged = &self.accumulated / scale;

        // Reset accumulator
        self.reset();

        Ok(averaged)
    }

    /// Reset the accumulator
    pub fn reset(&mut self) {
        self.accumulated.fill(A::zero());
        self.count = 0;
    }

    /// Check if accumulator has reached target count
    ///
    /// # Arguments
    ///
    /// * `target` - Target number of accumulations
    pub fn is_ready(&self, target: usize) -> bool {
        self.count >= target
    }
}

/// Chunked optimizer for processing large parameter arrays in chunks
///
/// Enables optimization of very large models by processing parameters
/// in manageable chunks, reducing peak memory usage.
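///
/// # Examples
///
/// A minimal sketch; it assumes `SGD` is publicly exposed as
/// `optirs_core::optimizers::SGD` (the tests below use `crate::optimizers::SGD`).
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::memory_efficient_optimizer::ChunkedOptimizer;
/// use optirs_core::optimizers::SGD;
///
/// // Update 25 parameters in chunks of 10 elements (3 chunks).
/// let mut chunked = ChunkedOptimizer::new(SGD::new(0.01), Some(10));
/// let params = Array1::from_elem(25, 1.0f32);
/// let gradients = Array1::from_elem(25, 0.1);
///
/// let updated = chunked.step_chunked(&params, &gradients).unwrap();
/// assert_eq!(updated.len(), 25);
/// assert_eq!(chunked.num_chunks(25), 3);
/// ```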
pub struct ChunkedOptimizer<O, A>
where
    O: Optimizer<A, Ix1> + Clone,
    A: Float + ScalarOperand + Debug,
{
    base_optimizer: O,
    chunk_size: usize,
    _phantom: std::marker::PhantomData<A>,
}

impl<O, A> ChunkedOptimizer<O, A>
where
    O: Optimizer<A, Ix1> + Clone,
    A: Float + ScalarOperand + Debug,
{
    /// Creates a new chunked optimizer
    ///
    /// # Arguments
    ///
    /// * `base_optimizer` - Base optimizer to use for each chunk
    /// * `chunk_size` - Size of each chunk (default: 1M elements)
    pub fn new(base_optimizer: O, chunk_size: Option<usize>) -> Self {
        let chunk_size = chunk_size.unwrap_or(1_000_000);

        Self {
            base_optimizer,
            chunk_size,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Process parameters in chunks
    ///
    /// # Arguments
    ///
    /// * `params` - Full parameter array
    /// * `gradients` - Full gradient array
    ///
    /// # Returns
    ///
    /// Updated parameters
    pub fn step_chunked(&mut self, params: &Array1<A>, gradients: &Array1<A>) -> Result<Array1<A>> {
        if params.len() != gradients.len() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Parameters ({}) and gradients ({}) must have same size",
                params.len(),
                gradients.len()
            )));
        }

        let total_size = params.len();
        let mut updated = Array1::zeros(total_size);

        // Process in chunks
        let num_chunks = total_size.div_ceil(self.chunk_size);

        for chunk_idx in 0..num_chunks {
            let start = chunk_idx * self.chunk_size;
            let end = (start + self.chunk_size).min(total_size);

            // Copy out the current chunk of parameters and gradients
            let params_chunk = params.slice(s![start..end]).to_owned();
            let grads_chunk = gradients.slice(s![start..end]).to_owned();

            // Update chunk
            let updated_chunk = self.base_optimizer.step(&params_chunk, &grads_chunk)?;

            // Copy back to result
            updated.slice_mut(s![start..end]).assign(&updated_chunk);
        }

        Ok(updated)
    }

    /// Get the chunk size
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Calculate number of chunks for given size
    pub fn num_chunks(&self, total_size: usize) -> usize {
        total_size.div_ceil(self.chunk_size)
    }
}

/// Memory usage estimator for optimizers
///
/// Provides utilities for estimating memory requirements and recommending
/// optimal configurations for different optimizer types.
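///
/// # Examples
///
/// A small sketch of the per-optimizer estimates (sizes in bytes, matching the
/// formulas implemented below):
///
/// ```
/// use optirs_core::memory_efficient_optimizer::MemoryUsageEstimator;
///
/// // Adam for 1M f32 parameters: params + grads + two moment buffers.
/// assert_eq!(MemoryUsageEstimator::adam(1_000_000, 4), 16_000_000);
///
/// // SGD with momentum: params + grads + velocity.
/// assert_eq!(MemoryUsageEstimator::sgd_with_momentum(1_000_000, 4), 12_000_000);
/// ```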
pub struct MemoryUsageEstimator;

impl MemoryUsageEstimator {
    /// Estimate memory usage for SGD without momentum
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    ///
    /// # Returns
    ///
    /// Estimated memory usage in bytes
    pub fn sgd(num_params: usize, dtype_size: usize) -> usize {
        // Parameters + gradients
        num_params * dtype_size * 2
    }

    /// Estimate memory usage for SGD with momentum
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    ///
    /// # Returns
    ///
    /// Estimated memory usage in bytes
    pub fn sgd_with_momentum(num_params: usize, dtype_size: usize) -> usize {
        // Parameters + gradients + velocity
        num_params * dtype_size * 3
    }

    /// Estimate memory usage for Adam optimizer
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    ///
    /// # Returns
    ///
    /// Estimated memory usage in bytes
    pub fn adam(num_params: usize, dtype_size: usize) -> usize {
        // Parameters + gradients + first moment + second moment
        num_params * dtype_size * 4
    }

    /// Recommend chunk size based on available memory
    ///
    /// # Arguments
    ///
    /// * `total_params` - Total number of parameters
    /// * `available_memory_bytes` - Available memory in bytes
    /// * `dtype_size` - Size of data type in bytes (4 for f32, 8 for f64)
    /// * `optimizer_state_multiplier` - Memory multiplier for optimizer state
    ///
    /// # Returns
    ///
    /// Recommended chunk size
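    ///
    /// # Examples
    ///
    /// A worked example matching the tests below: 100M f32 parameters, 1 GB of memory,
    /// and a state multiplier of 4 (Adam: parameters + gradients + two moment buffers).
    ///
    /// ```
    /// use optirs_core::memory_efficient_optimizer::MemoryUsageEstimator;
    ///
    /// // 1 GB / (4 bytes * 4) = 62.5M parameters fit; the 80% headroom leaves 50M.
    /// let chunk = MemoryUsageEstimator::recommend_chunk_size(
    ///     100_000_000,   // total parameters
    ///     1_000_000_000, // available memory in bytes
    ///     4,             // f32
    ///     4,             // Adam state multiplier
    /// );
    /// assert_eq!(chunk, 50_000_000);
    /// ```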
    pub fn recommend_chunk_size(
        total_params: usize,
        available_memory_bytes: usize,
        dtype_size: usize,
        optimizer_state_multiplier: usize,
    ) -> usize {
        let memory_per_param = dtype_size * optimizer_state_multiplier;
        let max_params = available_memory_bytes / memory_per_param;

        // Use 80% of available memory to leave headroom
        let safe_params = (max_params * 80) / 100;

        safe_params.min(total_params).max(1024)
    }

    /// Get recommended accumulation steps for given batch size
    ///
    /// # Arguments
    ///
    /// * `target_batch_size` - Desired effective batch size
    /// * `max_micro_batch_size` - Maximum micro-batch that fits in memory
    ///
    /// # Returns
    ///
    /// Number of gradient accumulation steps
    pub fn recommend_accumulation_steps(
        target_batch_size: usize,
        max_micro_batch_size: usize,
    ) -> usize {
        target_batch_size.div_ceil(max_micro_batch_size)
    }

    /// Estimate peak memory usage during training
    ///
    /// # Arguments
    ///
    /// * `num_params` - Number of parameters
    /// * `batch_size` - Batch size
    /// * `sequence_length` - Sequence length (for transformers, 1 otherwise)
    /// * `dtype_size` - Size of data type in bytes
    /// * `optimizer_type` - Type of optimizer ("sgd", "adam", etc.)
    ///
    /// # Returns
    ///
    /// Estimated peak memory in bytes
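    ///
    /// # Examples
    ///
    /// A rough sketch for a 10M-parameter model trained with Adam in f32; the activation
    /// term comes from the heuristic `hidden_dim ~ sqrt(num_params)` used below.
    ///
    /// ```
    /// use optirs_core::memory_efficient_optimizer::MemoryUsageEstimator;
    ///
    /// let peak = MemoryUsageEstimator::estimate_peak_memory(
    ///     10_000_000, // parameters
    ///     32,         // batch size
    ///     512,        // sequence length
    ///     4,          // f32
    ///     "adam",
    /// );
    /// // Parameters + gradients + Adam state alone account for 160 MB here.
    /// assert!(peak > 160_000_000);
    /// ```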
    pub fn estimate_peak_memory(
        num_params: usize,
        batch_size: usize,
        sequence_length: usize,
        dtype_size: usize,
        optimizer_type: &str,
    ) -> usize {
        // Model parameters
        let param_memory = num_params * dtype_size;

        // Gradients
        let grad_memory = num_params * dtype_size;

        // Optimizer state
        let optimizer_memory = match optimizer_type {
            "sgd" => num_params * dtype_size,
            "adam" | "adamw" => num_params * dtype_size * 2,
            _ => num_params * dtype_size,
        };

        // Activations (rough estimate: batch_size * sequence_length * hidden_dim)
        let hidden_dim = (num_params as f64).sqrt() as usize;
        let activation_memory = batch_size * sequence_length * hidden_dim * dtype_size;

        param_memory + grad_memory + optimizer_memory + activation_memory
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;
    use approx::assert_relative_eq;

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::<f32>::new(100);

        // Accumulate some gradients
        let grad1 = Array1::from_elem(100, 1.0);
        let grad2 = Array1::from_elem(100, 2.0);

        accumulator.accumulate(&grad1.view()).unwrap();
        accumulator.accumulate(&grad2.view()).unwrap();

        assert_eq!(accumulator.count(), 2);
        assert!(accumulator.is_ready(2));

        // Get average
        let avg = accumulator.average().unwrap();
        assert_relative_eq!(avg[0], 1.5, epsilon = 1e-6);

        // After average, accumulator should be reset
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_chunked_optimizer() {
        let optimizer = SGD::new(0.01);
        let mut chunked_opt = ChunkedOptimizer::new(optimizer, Some(10));

        let params = Array1::from_vec((0..25).map(|i| i as f32).collect());
        let gradients = Array1::from_elem(25, 0.1);

        let updated = chunked_opt.step_chunked(&params, &gradients).unwrap();

        // Verify updates
        assert_eq!(updated.len(), 25);
        assert_relative_eq!(updated[0], 0.0 - 0.01 * 0.1, epsilon = 1e-6);

        // Check number of chunks
        assert_eq!(chunked_opt.num_chunks(25), 3);
    }

    #[test]
    fn test_memory_estimator_sgd() {
        // SGD for 1M parameters (f32)
        let mem = MemoryUsageEstimator::sgd(1_000_000, 4);
        assert_eq!(mem, 8_000_000); // 8 MB

        // SGD with momentum
        let mem = MemoryUsageEstimator::sgd_with_momentum(1_000_000, 4);
        assert_eq!(mem, 12_000_000); // 12 MB
    }

    #[test]
    fn test_memory_estimator_adam() {
        // Adam for 1M parameters (f32)
        let mem = MemoryUsageEstimator::adam(1_000_000, 4);
        assert_eq!(mem, 16_000_000); // 16 MB
    }

    #[test]
    fn test_recommend_chunk_size() {
        // 1GB available, f32, Adam optimizer
        let chunk_size = MemoryUsageEstimator::recommend_chunk_size(
            100_000_000,   // 100M total params
            1_000_000_000, // 1GB available
            4,             // f32
            4,             // Adam state multiplier
        );

        // Should be around 50M params (80% of 62.5M that fits in 1GB)
        assert!(chunk_size > 40_000_000);
        assert!(chunk_size < 60_000_000);
    }

    #[test]
    fn test_recommend_accumulation_steps() {
        let steps = MemoryUsageEstimator::recommend_accumulation_steps(128, 32);
        assert_eq!(steps, 4);

        let steps = MemoryUsageEstimator::recommend_accumulation_steps(100, 32);
        assert_eq!(steps, 4); // Rounds up
    }

    #[test]
    fn test_estimate_peak_memory() {
        let peak = MemoryUsageEstimator::estimate_peak_memory(
            10_000_000, // 10M params
            32,         // batch size
            512,        // sequence length
            4,          // f32
            "adam",
        );

        // Should be substantial (model + optimizer + activations)
        assert!(peak > 100_000_000); // > 100MB
    }
}