ghostflow_nn/
zero_optimizer.rs

//! ZeRO Optimizer (Zero Redundancy Optimizer)
//!
//! Implements memory-efficient distributed training:
//! - ZeRO Stage 1: Optimizer state partitioning
//! - ZeRO Stage 2: Gradient partitioning
//! - ZeRO Stage 3: Parameter partitioning
//! - ZeRO-Offload: CPU/NVMe offloading
//! - Communication optimization
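//!
//! # Example
//!
//! A minimal single-process sketch (`world_size = 1`, so the collective ops
//! are local no-ops). The parameter name `"w"` is a placeholder, and the
//! `use` paths assume this module is exported as `ghostflow_nn::zero_optimizer`:
//!
//! ```ignore
//! use std::collections::HashMap;
//! use ghostflow_core::Tensor;
//! use ghostflow_nn::zero_optimizer::{ZeRoConfig, ZeRoOptimizer};
//!
//! let mut opt = ZeRoOptimizer::new(ZeRoConfig::stage2(1, 0), 1e-3);
//!
//! let mut params = HashMap::new();
//! params.insert("w".to_string(), Tensor::randn(&[10, 10]));
//! opt.partition_parameters(&params).unwrap();
//!
//! // After a backward pass has produced gradients:
//! let mut grads = HashMap::new();
//! grads.insert("w".to_string(), Tensor::randn(&[10, 10]));
//! opt.partition_gradients(&grads).unwrap();
//! opt.reduce_scatter_gradients().unwrap();
//! opt.step().unwrap();
//! ```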
use ghostflow_core::Tensor;
use std::collections::HashMap;

/// ZeRO stage configuration
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ZeRoStage {
    /// Stage 1: Optimizer state partitioning
    Stage1,
    /// Stage 2: Gradient partitioning
    Stage2,
    /// Stage 3: Parameter partitioning
    Stage3,
}

/// ZeRO configuration
#[derive(Debug, Clone)]
pub struct ZeRoConfig {
    /// ZeRO stage
    pub stage: ZeRoStage,
    /// Number of processes/GPUs
    pub world_size: usize,
    /// Current process rank
    pub rank: usize,
    /// Enable CPU offloading
    pub cpu_offload: bool,
    /// Enable NVMe offloading
    pub nvme_offload: bool,
    /// Overlap communication with computation
    pub overlap_comm: bool,
    /// Bucket size (in elements) for bucketed gradient communication
    pub bucket_size: usize,
}

impl Default for ZeRoConfig {
    fn default() -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage2,
            world_size: 1,
            rank: 0,
            cpu_offload: false,
            nvme_offload: false,
            overlap_comm: true,
            bucket_size: 25_000_000, // 25M elements
        }
    }
}

impl ZeRoConfig {
    /// Stage 1 configuration (optimizer state partitioning)
    pub fn stage1(world_size: usize, rank: usize) -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage1,
            world_size,
            rank,
            ..Default::default()
        }
    }

    /// Stage 2 configuration (gradient partitioning)
    pub fn stage2(world_size: usize, rank: usize) -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage2,
            world_size,
            rank,
            ..Default::default()
        }
    }

    /// Stage 3 configuration (parameter partitioning)
    pub fn stage3(world_size: usize, rank: usize) -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage3,
            world_size,
            rank,
            ..Default::default()
        }
    }

    /// ZeRO-Offload configuration
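    ///
    /// For example, `ZeRoConfig::stage3(8, rank).with_offload(true, false)`
    /// enables CPU offload without NVMe. This only sets the flags; the actual
    /// data movement is performed by [`ZeRoOptimizer::offload_to_cpu`] and
    /// [`ZeRoOptimizer::load_from_cpu`].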
    pub fn with_offload(mut self, cpu: bool, nvme: bool) -> Self {
        self.cpu_offload = cpu;
        self.nvme_offload = nvme;
        self
    }
}

/// Parameter partition information
#[derive(Debug, Clone)]
pub struct ParameterPartition {
    /// Parameter name
    pub name: String,
    /// Owner rank
    pub owner_rank: usize,
    /// Start index in flattened parameters
    pub start_idx: usize,
    /// End index in flattened parameters
    pub end_idx: usize,
    /// Original shape
    pub shape: Vec<usize>,
}

/// ZeRO optimizer state
pub struct ZeRoOptimizer {
    config: ZeRoConfig,
    /// Partitioned parameters (only parameters owned by this rank)
    partitioned_params: HashMap<String, Tensor>,
    /// Partitioned gradients
    partitioned_grads: HashMap<String, Tensor>,
    /// Partitioned optimizer states (momentum, variance, etc.)
    partitioned_states: HashMap<String, HashMap<String, Tensor>>,
    /// Parameter partition map
    param_partitions: Vec<ParameterPartition>,
    /// Gradient buckets for communication (reserved; not yet used)
    gradient_buckets: Vec<Vec<String>>,
    /// Communication buffer (reserved; not yet used)
    comm_buffer: Vec<f32>,
    /// CPU offload buffer
    cpu_buffer: HashMap<String, Vec<f32>>,
    /// Learning rate
    learning_rate: f32,
}

impl ZeRoOptimizer {
    /// Create new ZeRO optimizer
    pub fn new(config: ZeRoConfig, learning_rate: f32) -> Self {
        ZeRoOptimizer {
            config,
            partitioned_params: HashMap::new(),
            partitioned_grads: HashMap::new(),
            partitioned_states: HashMap::new(),
            param_partitions: Vec::new(),
            gradient_buckets: Vec::new(),
            comm_buffer: Vec::new(),
            cpu_buffer: HashMap::new(),
            learning_rate,
        }
    }

    /// Partition parameters across ranks
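    ///
    /// Parameters are flattened in sorted-name order and split into contiguous
    /// chunks of `ceil(total / world_size)` elements; each tensor is owned
    /// entirely by the rank whose chunk contains its start index. Worked
    /// example: with 2 ranks and two tensors of 100 and 400 elements (in that
    /// sorted order), the chunk size is 250; both start indices (0 and 100)
    /// land in rank 0's chunk, so rank 0 owns everything. Ownership is
    /// whole-tensor granular, not element-exact.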
    pub fn partition_parameters(&mut self, params: &HashMap<String, Tensor>) -> Result<(), String> {
        let total_params: usize = params.values()
            .map(|t| t.data_f32().len())
            .sum();

        // Ceiling division: elements per rank (min 1 to avoid division by zero)
        let params_per_rank =
            ((total_params + self.config.world_size - 1) / self.config.world_size).max(1);

        // Iterate in sorted-name order so every rank computes the same
        // partition map regardless of HashMap iteration order.
        let mut names: Vec<&String> = params.keys().collect();
        names.sort();

        let mut current_idx = 0;

        for name in names {
            let tensor = &params[name];
            let param_size = tensor.data_f32().len();
            let start_idx = current_idx;
            let end_idx = current_idx + param_size;

            // A tensor is owned by the rank whose chunk contains its start index
            let owner_rank = current_idx / params_per_rank;

            let partition = ParameterPartition {
                name: name.clone(),
                owner_rank,
                start_idx,
                end_idx,
                shape: tensor.dims().to_vec(),
            };

            self.param_partitions.push(partition);

            // Store the parameter if it is owned by this rank
            if owner_rank == self.config.rank {
                self.partitioned_params.insert(name.clone(), tensor.clone());

                // Initialize Adam-style optimizer states to zero
                let dims = tensor.dims();
                let size = tensor.data_f32().len();
                let zeros_data = vec![0.0f32; size];

                let mut states = HashMap::new();
                states.insert(
                    "momentum".to_string(),
                    Tensor::from_slice(&zeros_data, dims)
                        .map_err(|e| format!("Failed to create momentum state: {:?}", e))?,
                );
                states.insert(
                    "variance".to_string(),
                    Tensor::from_slice(&zeros_data, dims)
                        .map_err(|e| format!("Failed to create variance state: {:?}", e))?,
                );
                self.partitioned_states.insert(name.clone(), states);
            }

            current_idx = end_idx;
        }

        Ok(())
    }

    /// Partition gradients (Stage 2+)
    pub fn partition_gradients(&mut self, grads: &HashMap<String, Tensor>) -> Result<(), String> {
        if self.config.stage == ZeRoStage::Stage1 {
            // Stage 1: keep all gradients; only optimizer states are partitioned
            for (name, grad) in grads {
                self.partitioned_grads.insert(name.clone(), grad.clone());
            }
            return Ok(());
        }

        // Stage 2+: keep only the gradients of parameters owned by this rank
        for partition in &self.param_partitions {
            if partition.owner_rank == self.config.rank {
                if let Some(grad) = grads.get(&partition.name) {
                    self.partitioned_grads.insert(partition.name.clone(), grad.clone());
                }
            }
        }

        Ok(())
    }

    /// Reduce-scatter gradients across ranks
    pub fn reduce_scatter_gradients(&mut self) -> Result<(), String> {
        // Simulated reduce-scatter: a real implementation would use NCCL (or
        // similar) so that each rank ends up with the sum of its gradient
        // shard across all ranks, divided by world_size. Here we only divide
        // the local gradients, which matches that result when every rank
        // holds identical gradients.

        let grad_names: Vec<String> = self.partitioned_grads.keys().cloned().collect();

        for name in grad_names {
            if let Some(grad) = self.partitioned_grads.get(&name) {
                let data = grad.data_f32();
                let averaged: Vec<f32> = data.iter()
                    .map(|&x| x / self.config.world_size as f32)
                    .collect();

                let averaged_grad = Tensor::from_slice(&averaged, grad.dims())
                    .map_err(|e| format!("Failed to create averaged gradient: {:?}", e))?;

                self.partitioned_grads.insert(name, averaged_grad);
            }
        }

        Ok(())
    }

    /// All-gather parameters (Stage 3)
    pub fn all_gather_parameters(&self) -> Result<HashMap<String, Tensor>, String> {
        if self.config.stage != ZeRoStage::Stage3 {
            return Ok(self.partitioned_params.clone());
        }

        // Simulated all-gather: a real implementation would collect parameter
        // shards from every rank to reconstruct the full parameter set. Here
        // we can only return the locally owned parameters.
        let mut all_params = HashMap::new();

        for (name, param) in &self.partitioned_params {
            all_params.insert(name.clone(), param.clone());
        }

        Ok(all_params)
    }

    /// Optimizer step with ZeRO
    ///
    /// Applies an Adam-style update to the parameters owned by this rank:
    /// `m <- b1*m + (1-b1)*g`, `v <- b2*v + (1-b2)*g^2`,
    /// `p <- p - lr * m / (sqrt(v) + eps)`. Bias correction is omitted for
    /// simplicity.
    pub fn step(&mut self) -> Result<(), String> {
        // Update only the parameters owned by this rank
        for (name, param) in &mut self.partitioned_params {
            if let Some(grad) = self.partitioned_grads.get(name) {
                if let Some(states) = self.partitioned_states.get_mut(name) {
                    // Adam hyperparameters
                    let beta1 = 0.9;
                    let beta2 = 0.999;
                    let eps = 1e-8;

                    // Read current state
                    let m_data = states.get("momentum").unwrap().data_f32();
                    let v_data = states.get("variance").unwrap().data_f32();
                    let g_data = grad.data_f32();
                    let p_data = param.data_f32();

                    let mut new_m = Vec::with_capacity(m_data.len());
                    let mut new_v = Vec::with_capacity(v_data.len());
                    let mut new_p = Vec::with_capacity(p_data.len());

                    for i in 0..m_data.len() {
                        let m = beta1 * m_data[i] + (1.0 - beta1) * g_data[i];
                        let v = beta2 * v_data[i] + (1.0 - beta2) * g_data[i] * g_data[i];
                        let p = p_data[i] - self.learning_rate * m / (v.sqrt() + eps);

                        new_m.push(m);
                        new_v.push(v);
                        new_p.push(p);
                    }

                    // Capture dims before replacing the tensors
                    let m_dims = states.get("momentum").unwrap().dims().to_vec();
                    let v_dims = states.get("variance").unwrap().dims().to_vec();
                    let p_dims = param.dims().to_vec();

                    // Write back updated states and parameters
                    states.insert("momentum".to_string(), Tensor::from_slice(&new_m, &m_dims)
                        .map_err(|e| format!("Failed to create momentum tensor: {:?}", e))?);
                    states.insert("variance".to_string(), Tensor::from_slice(&new_v, &v_dims)
                        .map_err(|e| format!("Failed to create variance tensor: {:?}", e))?);
                    *param = Tensor::from_slice(&new_p, &p_dims)
                        .map_err(|e| format!("Failed to create param tensor: {:?}", e))?;
                }
            }
        }

        Ok(())
    }

    /// Offload a parameter to CPU
    pub fn offload_to_cpu(&mut self, name: &str) -> Result<(), String> {
        if !self.config.cpu_offload {
            return Ok(());
        }

        if let Some(param) = self.partitioned_params.get(name) {
            let data = param.data_f32().to_vec();
            self.cpu_buffer.insert(name.to_string(), data);
            // A real implementation would also free the GPU copy here
        }

        Ok(())
    }

    /// Load a parameter back from CPU
    pub fn load_from_cpu(&mut self, name: &str) -> Result<(), String> {
        if !self.config.cpu_offload {
            return Ok(());
        }

        if let Some(data) = self.cpu_buffer.get(name) {
            if let Some(partition) = self.param_partitions.iter().find(|p| p.name == name) {
                let tensor = Tensor::from_slice(data, &partition.shape)
                    .map_err(|e| format!("Failed to load from CPU: {:?}", e))?;
                self.partitioned_params.insert(name.to_string(), tensor);
            }
        }

        Ok(())
    }

    /// Get memory savings ratio
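    ///
    /// Fraction of per-rank training memory saved relative to plain data
    /// parallelism, assuming fp32 Adam, where memory is roughly params +
    /// grads + 2x optimizer state (momentum + variance): optimizer state is
    /// ~1/2 of the total and grads + optimizer state are ~3/4. For example,
    /// with `world_size = 4`: Stage 1 -> 0.375, Stage 2 -> 0.5625,
    /// Stage 3 -> 0.75.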
    pub fn memory_savings_ratio(&self) -> f32 {
        match self.config.stage {
            ZeRoStage::Stage1 => {
                // Stage 1: only optimizer states are partitioned; for fp32
                // Adam they are ~50% of total memory,
                // so savings = (N - 1) / N * 0.5
                let n = self.config.world_size as f32;
                (n - 1.0) / n * 0.5
            }
            ZeRoStage::Stage2 => {
                // Stage 2: optimizer states + gradients are partitioned,
                // together ~75% of total memory,
                // so savings = (N - 1) / N * 0.75
                let n = self.config.world_size as f32;
                (n - 1.0) / n * 0.75
            }
            ZeRoStage::Stage3 => {
                // Stage 3: parameters, gradients, and optimizer states are
                // all partitioned, so savings = (N - 1) / N
                let n = self.config.world_size as f32;
                (n - 1.0) / n
            }
        }
    }

    /// Get statistics
    pub fn get_stats(&self) -> ZeRoStats {
        let total_params: usize = self.partitioned_params.values()
            .map(|t| t.data_f32().len())
            .sum();

        let total_grads: usize = self.partitioned_grads.values()
            .map(|t| t.data_f32().len())
            .sum();

        ZeRoStats {
            stage: self.config.stage,
            world_size: self.config.world_size,
            rank: self.config.rank,
            num_partitioned_params: self.partitioned_params.len(),
            total_param_elements: total_params,
            total_grad_elements: total_grads,
            memory_savings: self.memory_savings_ratio(),
            cpu_offload_enabled: self.config.cpu_offload,
        }
    }
}

/// ZeRO statistics
#[derive(Debug, Clone)]
pub struct ZeRoStats {
    pub stage: ZeRoStage,
    pub world_size: usize,
    pub rank: usize,
    pub num_partitioned_params: usize,
    pub total_param_elements: usize,
    pub total_grad_elements: usize,
    pub memory_savings: f32,
    pub cpu_offload_enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_config() {
        let config = ZeRoConfig::default();
        assert_eq!(config.stage, ZeRoStage::Stage2);
        assert_eq!(config.world_size, 1);

        let stage3 = ZeRoConfig::stage3(4, 0);
        assert_eq!(stage3.stage, ZeRoStage::Stage3);
        assert_eq!(stage3.world_size, 4);
    }

    #[test]
    fn test_zero_optimizer_creation() {
        let config = ZeRoConfig::stage2(4, 0);
        let optimizer = ZeRoOptimizer::new(config, 0.001);

        let stats = optimizer.get_stats();
        assert_eq!(stats.world_size, 4);
        assert_eq!(stats.rank, 0);
    }

    #[test]
    fn test_partition_parameters() {
        let config = ZeRoConfig::stage2(2, 0);
        let mut optimizer = ZeRoOptimizer::new(config, 0.001);

        let mut params = HashMap::new();
        params.insert("layer1".to_string(), Tensor::randn(&[10, 10]));
        params.insert("layer2".to_string(), Tensor::randn(&[20, 20]));

        optimizer.partition_parameters(&params).unwrap();
        assert!(!optimizer.param_partitions.is_empty());
    }

    #[test]
    fn test_memory_savings_ratio() {
        let config1 = ZeRoConfig::stage1(4, 0);
        let optimizer1 = ZeRoOptimizer::new(config1, 0.001);
        let savings1 = optimizer1.memory_savings_ratio();

        let config2 = ZeRoConfig::stage2(4, 0);
        let optimizer2 = ZeRoOptimizer::new(config2, 0.001);
        let savings2 = optimizer2.memory_savings_ratio();

        let config3 = ZeRoConfig::stage3(4, 0);
        let optimizer3 = ZeRoOptimizer::new(config3, 0.001);
        let savings3 = optimizer3.memory_savings_ratio();

        // Stage 3 should have the highest savings
        assert!(savings3 > savings2);
        assert!(savings2 > savings1);
    }

    #[test]
    fn test_offload_config() {
        let config = ZeRoConfig::stage3(4, 0)
            .with_offload(true, false);

        assert!(config.cpu_offload);
        assert!(!config.nvme_offload);
    }

    #[test]
    fn test_cpu_offload() {
        let config = ZeRoConfig::stage2(2, 0).with_offload(true, false);
        let mut optimizer = ZeRoOptimizer::new(config, 0.001);

        let mut params = HashMap::new();
        params.insert("layer1".to_string(), Tensor::randn(&[5, 5]));

        optimizer.partition_parameters(&params).unwrap();
        optimizer.offload_to_cpu("layer1").unwrap();

        assert!(optimizer.cpu_buffer.contains_key("layer1"));
    }
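
    // The tests below are additional sketches of the intended flow. They rely
    // only on the simulated single-process behavior above and on the exact
    // fp32-Adam fractions used in memory_savings_ratio; "w" is a placeholder
    // parameter name.

    #[test]
    fn test_step_updates_owned_params() {
        // world_size = 1: this rank owns everything and reduce-scatter is a no-op
        let config = ZeRoConfig::stage2(1, 0);
        let mut optimizer = ZeRoOptimizer::new(config, 0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Tensor::from_slice(&[1.0f32; 4], &[2, 2]).unwrap());
        optimizer.partition_parameters(&params).unwrap();

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Tensor::from_slice(&[1.0f32; 4], &[2, 2]).unwrap());
        optimizer.partition_gradients(&grads).unwrap();

        optimizer.reduce_scatter_gradients().unwrap();
        optimizer.step().unwrap();

        // A positive gradient must decrease the parameter values
        let updated = optimizer.partitioned_params.get("w").unwrap().data_f32();
        assert!(updated.iter().all(|&x| x < 1.0));
    }

    #[test]
    fn test_memory_savings_exact_values() {
        // With world_size = 4: Stage 1 -> 0.375, Stage 2 -> 0.5625, Stage 3 -> 0.75
        let s1 = ZeRoOptimizer::new(ZeRoConfig::stage1(4, 0), 0.001).memory_savings_ratio();
        let s2 = ZeRoOptimizer::new(ZeRoConfig::stage2(4, 0), 0.001).memory_savings_ratio();
        let s3 = ZeRoOptimizer::new(ZeRoConfig::stage3(4, 0), 0.001).memory_savings_ratio();

        assert!((s1 - 0.375).abs() < 1e-6);
        assert!((s2 - 0.5625).abs() < 1e-6);
        assert!((s3 - 0.75).abs() < 1e-6);
    }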
}