ghostflow_nn/distributed.rs

//! Distributed Training
//!
//! Support for multi-GPU training with data and model parallelism.
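//!
//! # Example
//!
//! A minimal single-node sketch of data-parallel usage. The public module
//! path `ghostflow_nn::distributed` is assumed here, so the snippet is
//! marked `ignore` and serves as an illustration rather than a doc-test.
//!
//! ```ignore
//! use ghostflow_core::tensor::Tensor;
//! use ghostflow_nn::distributed::{DataParallel, DistributedConfig};
//!
//! // Two logical devices on one node; all communication is simulated in-process.
//! let config = DistributedConfig { world_size: 2, ..Default::default() };
//! let dp = DataParallel::new(config, vec![0, 1]);
//!
//! // A [4, 2] batch is split into two [2, 2] shards, one per device.
//! let batch = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2]).unwrap();
//! let shards = dp.split_batch(&batch);
//! assert_eq!(shards.len(), 2);
//! ```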

use ghostflow_core::tensor::Tensor;
use std::sync::{Arc, Mutex};
use std::collections::HashMap;

/// Distributed training backend
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DistributedBackend {
    /// NVIDIA NCCL for GPU communication
    NCCL,
    /// Gloo for CPU/GPU communication
    Gloo,
    /// MPI for HPC environments
    MPI,
}

/// Distributed training configuration
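///
/// A sketch of configuring one process of a two-process job; the address and
/// port values below are illustrative, not required defaults.
///
/// ```ignore
/// let config = DistributedConfig {
///     backend: DistributedBackend::NCCL,
///     world_size: 2,
///     rank: 0, // this process; the peer process would use rank: 1
///     master_addr: "127.0.0.1".to_string(),
///     master_port: 29500,
/// };
/// ```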
#[derive(Clone, Debug)]
pub struct DistributedConfig {
    /// Backend to use
    pub backend: DistributedBackend,
    /// World size (total number of processes)
    pub world_size: usize,
    /// Rank of this process
    pub rank: usize,
    /// Master address for coordination
    pub master_addr: String,
    /// Master port
    pub master_port: u16,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            backend: DistributedBackend::NCCL,
            world_size: 1,
            rank: 0,
            master_addr: "localhost".to_string(),
            master_port: 29500,
        }
    }
}

/// Data parallel training
///
/// Replicates the model across GPUs and splits data batches.
/// Gradients are averaged across all GPUs after backward pass.
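///
/// The sketch below shows the intended call sequence for one training step;
/// the parameter name `"fc1.weight"` is a placeholder.
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let dp = DataParallel::new(DistributedConfig::default(), vec![0, 1]);
///
/// // 1. Shard a [4, 2] batch into one [2, 2] shard per device.
/// let batch = Tensor::from_slice(&[0.0f32; 8], &[4, 2]).unwrap();
/// let shards = dp.split_batch(&batch);
/// assert_eq!(shards.len(), 2);
///
/// // 2. After the local backward passes, average gradients across devices
/// //    and feed the result to the optimizer.
/// let mut grads = HashMap::new();
/// grads.insert("fc1.weight".to_string(), Tensor::from_slice(&[0.5f32, 0.5], &[2]).unwrap());
/// let synced = dp.all_reduce_gradients(&grads);
/// ```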
pub struct DataParallel {
    config: DistributedConfig,
    device_ids: Vec<usize>,
    /// Per-parameter gradient buckets; reserved for bucketed all-reduce and
    /// not yet used by this simplified implementation.
    gradient_buckets: Arc<Mutex<HashMap<String, Vec<Tensor>>>>,
}

impl DataParallel {
    pub fn new(config: DistributedConfig, device_ids: Vec<usize>) -> Self {
        Self {
            config,
            device_ids,
            gradient_buckets: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Split batch across devices
    pub fn split_batch(&self, batch: &Tensor) -> Vec<Tensor> {
        let batch_size = batch.shape().dims()[0];
        let per_device = batch_size / self.device_ids.len();

        let mut splits = Vec::new();
        for i in 0..self.device_ids.len() {
            let start = i * per_device;
            let end = if i == self.device_ids.len() - 1 {
                batch_size
            } else {
                (i + 1) * per_device
            };

            // Create a view of the batch for this device
            // In practice, this would also move to the specific GPU
            let split = self.slice_batch(batch, start, end);
            splits.push(split);
        }

        splits
    }

    fn slice_batch(&self, batch: &Tensor, start: usize, end: usize) -> Tensor {
        // Simplified batch slicing
        // In practice, would use proper tensor slicing
        let batch_data = batch.storage().as_slice::<f32>();
        let dims = batch.shape().dims();
        let row_size = dims[1..].iter().product::<usize>();

        let slice_data: Vec<f32> = batch_data[start * row_size..end * row_size].to_vec();
        let mut new_dims = dims.to_vec();
        new_dims[0] = end - start;

        Tensor::from_slice(&slice_data, &new_dims).unwrap()
    }

    /// All-reduce gradients across devices
    pub fn all_reduce_gradients(&self, gradients: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        let mut averaged = HashMap::new();

        for (name, grad) in gradients {
            // Simulate all-reduce by averaging
            // In practice, this would use NCCL/Gloo for actual GPU communication
            let averaged_grad = self.average_gradient(grad);
            averaged.insert(name.clone(), averaged_grad);
        }

        averaged
    }

    fn average_gradient(&self, grad: &Tensor) -> Tensor {
        // Simulate only the scaling half of an all-reduce average: a real
        // backend would first sum the gradients from all ranks via collective
        // communication, then divide by the world size.
        let scale = 1.0 / self.config.world_size as f32;

        let grad_data = grad.storage().as_slice::<f32>();
        let scaled_data: Vec<f32> = grad_data.iter().map(|&x| x * scale).collect();

        Tensor::from_slice(&scaled_data, grad.shape().dims()).unwrap()
    }

    /// Broadcast model parameters from rank 0 to all ranks
    pub fn broadcast_parameters(&self, parameters: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        if self.config.rank == 0 {
            // Rank 0 keeps its parameters
            parameters.clone()
        } else {
            // Other ranks would receive from rank 0
            // In practice, this would use actual broadcast communication
            parameters.clone()
        }
    }
}

/// Model parallel training
///
/// Splits the model across multiple GPUs.
/// Different layers run on different devices.
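///
/// A small placement sketch; the layer names are placeholders. Activations
/// that cross a device boundary are moved with `transfer`.
///
/// ```ignore
/// let mut mp = ModelParallel::new(DistributedConfig::default());
///
/// // Keep the embedding and first block on GPU 0, the head on GPU 1.
/// mp.place_layer("embedding", 0);
/// mp.place_layer("block0", 0);
/// mp.place_layer("head", 1);
///
/// assert_eq!(mp.get_device("block0"), Some(0));
/// assert_eq!(mp.get_device("head"), Some(1));
/// ```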
pub struct ModelParallel {
    config: DistributedConfig,
    layer_placement: HashMap<String, usize>,
}

impl ModelParallel {
    pub fn new(config: DistributedConfig) -> Self {
        Self {
            config,
            layer_placement: HashMap::new(),
        }
    }

    /// Assign a layer to a specific device
    pub fn place_layer(&mut self, layer_name: &str, device_id: usize) {
        self.layer_placement.insert(layer_name.to_string(), device_id);
    }

    /// Get device for a layer
    pub fn get_device(&self, layer_name: &str) -> Option<usize> {
        self.layer_placement.get(layer_name).copied()
    }

    /// Automatic layer placement using a simple strategy
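    ///
    /// For example, four layers on two devices end up as two contiguous groups
    /// (a sketch; the layer names are placeholders):
    ///
    /// ```ignore
    /// let mut mp = ModelParallel::new(DistributedConfig::default());
    /// let names: Vec<String> = ["l0", "l1", "l2", "l3"].iter().map(|s| s.to_string()).collect();
    ///
    /// mp.auto_place_layers(&names, 2);
    /// assert_eq!(mp.get_device("l0"), Some(0));
    /// assert_eq!(mp.get_device("l1"), Some(0));
    /// assert_eq!(mp.get_device("l2"), Some(1));
    /// assert_eq!(mp.get_device("l3"), Some(1));
    /// ```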
    pub fn auto_place_layers(&mut self, layer_names: &[String], num_devices: usize) {
        let layers_per_device = (layer_names.len() + num_devices - 1) / num_devices;

        for (i, name) in layer_names.iter().enumerate() {
            let device = i / layers_per_device;
            self.place_layer(name, device.min(num_devices - 1));
        }
    }

    /// Transfer tensor between devices
    pub fn transfer(&self, tensor: &Tensor, _from_device: usize, _to_device: usize) -> Tensor {
        // In practice, this would perform actual GPU-to-GPU transfer
        // For now, just clone the tensor
        tensor.clone()
    }
}

/// Gradient accumulation
///
/// Accumulates gradients over multiple micro-batches before updating.
/// Useful for training with large effective batch sizes on limited memory.
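///
/// With `accumulation_steps = 2`, two micro-batch gradients are summed and then
/// averaged when collected (a sketch; `"w"` is a placeholder parameter name).
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut acc = GradientAccumulator::new(2);
/// let mut grads = HashMap::new();
/// grads.insert("w".to_string(), Tensor::from_slice(&[2.0f32], &[1]).unwrap());
///
/// acc.accumulate(&grads);              // micro-batch 1 of 2
/// assert!(!acc.should_update());
/// acc.accumulate(&grads);              // micro-batch 2 of 2
/// assert!(acc.should_update());
///
/// // get_and_reset returns the average: (2.0 + 2.0) / 2 = 2.0 here.
/// let averaged = acc.get_and_reset();
/// ```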
pub struct GradientAccumulator {
    accumulation_steps: usize,
    current_step: usize,
    accumulated_gradients: HashMap<String, Tensor>,
}

impl GradientAccumulator {
    pub fn new(accumulation_steps: usize) -> Self {
        Self {
            accumulation_steps,
            current_step: 0,
            accumulated_gradients: HashMap::new(),
        }
    }

    /// Accumulate gradients from a micro-batch
    pub fn accumulate(&mut self, gradients: &HashMap<String, Tensor>) {
        for (name, grad) in gradients {
            let should_add = self.accumulated_gradients.contains_key(name);

            if should_add {
                // Add to existing accumulated gradient
                let accumulated = self.accumulated_gradients.get(name).unwrap();
                let sum = self.add_tensors(accumulated, grad);
                self.accumulated_gradients.insert(name.clone(), sum);
            } else {
                // First accumulation
                self.accumulated_gradients.insert(name.clone(), grad.clone());
            }
        }

        self.current_step += 1;
    }

    fn add_tensors(&self, a: &Tensor, b: &Tensor) -> Tensor {
        let a_data = a.storage().as_slice::<f32>();
        let b_data = b.storage().as_slice::<f32>();

        let sum: Vec<f32> = a_data.iter().zip(b_data.iter())
            .map(|(x, y)| x + y)
            .collect();

        Tensor::from_slice(&sum, a.shape().dims()).unwrap()
    }

    /// Check if ready to update (accumulated enough steps)
    pub fn should_update(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    /// Get accumulated gradients, averaged over the accumulation steps, and reset
    pub fn get_and_reset(&mut self) -> HashMap<String, Tensor> {
        let gradients = self.accumulated_gradients.clone();
        self.accumulated_gradients.clear();
        self.current_step = 0;

        // Average over the number of accumulation steps
        let scale = 1.0 / self.accumulation_steps as f32;
        gradients.into_iter()
            .map(|(name, grad)| {
                let scaled = self.scale_tensor(&grad, scale);
                (name, scaled)
            })
            .collect()
    }

    fn scale_tensor(&self, tensor: &Tensor, scale: f32) -> Tensor {
        let data = tensor.storage().as_slice::<f32>();
        let scaled: Vec<f32> = data.iter().map(|&x| x * scale).collect();
        Tensor::from_slice(&scaled, tensor.shape().dims()).unwrap()
    }

    /// Reset accumulator
    pub fn reset(&mut self) {
        self.accumulated_gradients.clear();
        self.current_step = 0;
    }
}

/// Distributed Data Parallel (DDP)
///
/// Combines data parallelism with gradient synchronization after each
/// backward pass. A full implementation overlaps communication with
/// computation; this version synchronizes eagerly and in-process.
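///
/// The sketch below wires the pieces together for one optimizer step with
/// two-step gradient accumulation; `"w"` is a placeholder parameter name.
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let config = DistributedConfig { world_size: 2, ..Default::default() };
/// let mut ddp = DistributedDataParallel::new(config, vec![0, 1], Some(2));
///
/// // Shard the batch; the per-shard forward/backward happens elsewhere.
/// let batch = Tensor::from_slice(&[0.0f32; 4], &[2, 2]).unwrap();
/// let shards = ddp.forward(&batch);
///
/// let mut grads = HashMap::new();
/// grads.insert("w".to_string(), Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap());
///
/// assert!(ddp.backward(&grads).is_none()); // micro-batch 1: accumulate only
/// assert!(ddp.backward(&grads).is_some()); // micro-batch 2: synced grads ready
/// ```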
pub struct DistributedDataParallel {
    data_parallel: DataParallel,
    gradient_accumulator: Option<GradientAccumulator>,
    find_unused_parameters: bool,
}

impl DistributedDataParallel {
    pub fn new(
        config: DistributedConfig,
        device_ids: Vec<usize>,
        gradient_accumulation_steps: Option<usize>,
    ) -> Self {
        let gradient_accumulator = gradient_accumulation_steps
            .map(GradientAccumulator::new);

        Self {
            data_parallel: DataParallel::new(config, device_ids),
            gradient_accumulator,
            find_unused_parameters: false,
        }
    }

    /// Enable finding unused parameters
    pub fn find_unused_parameters(mut self, enabled: bool) -> Self {
        self.find_unused_parameters = enabled;
        self
    }

    /// Forward pass with data parallelism
    pub fn forward(&self, batch: &Tensor) -> Vec<Tensor> {
        self.data_parallel.split_batch(batch)
    }

    /// Backward pass with gradient synchronization
    pub fn backward(&mut self, gradients: &HashMap<String, Tensor>) -> Option<HashMap<String, Tensor>> {
        // All-reduce gradients across devices
        let reduced_gradients = self.data_parallel.all_reduce_gradients(gradients);

        // Handle gradient accumulation if enabled
        if let Some(ref mut accumulator) = self.gradient_accumulator {
            accumulator.accumulate(&reduced_gradients);

            if accumulator.should_update() {
                Some(accumulator.get_and_reset())
            } else {
                None
            }
        } else {
            Some(reduced_gradients)
        }
    }

    /// Synchronize parameters across all processes
    pub fn sync_parameters(&self, parameters: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        self.data_parallel.broadcast_parameters(parameters)
    }
}

/// Pipeline parallelism
///
/// Splits the model into stages and processes micro-batches in a pipeline.
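///
/// A sketch of the micro-batching side; stage scheduling is left to the caller.
///
/// ```ignore
/// let pp = PipelineParallel::new(/* num_stages */ 2, /* num_micro_batches */ 4);
///
/// // An [8, 4] batch becomes four [2, 4] micro-batches.
/// let batch = Tensor::from_slice(&vec![0.0f32; 32], &[8, 4]).unwrap();
/// let micro_batches = pp.create_micro_batches(&batch);
/// assert_eq!(micro_batches.len(), 4);
/// assert_eq!(micro_batches[0].shape().dims()[0], 2);
/// ```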
pub struct PipelineParallel {
    num_stages: usize,
    num_micro_batches: usize,
    current_stage: usize,
}

impl PipelineParallel {
    pub fn new(num_stages: usize, num_micro_batches: usize) -> Self {
        Self {
            num_stages,
            num_micro_batches,
            current_stage: 0,
        }
    }

    /// Split batch into micro-batches
    pub fn create_micro_batches(&self, batch: &Tensor) -> Vec<Tensor> {
        let batch_size = batch.shape().dims()[0];
        let micro_batch_size = batch_size / self.num_micro_batches;

        let mut micro_batches = Vec::new();
        for i in 0..self.num_micro_batches {
            let start = i * micro_batch_size;
            let end = if i == self.num_micro_batches - 1 {
                batch_size
            } else {
                (i + 1) * micro_batch_size
            };

            let micro_batch = self.slice_batch(batch, start, end);
            micro_batches.push(micro_batch);
        }

        micro_batches
    }

    fn slice_batch(&self, batch: &Tensor, start: usize, end: usize) -> Tensor {
        let batch_data = batch.storage().as_slice::<f32>();
        let dims = batch.shape().dims();
        let row_size = dims[1..].iter().product::<usize>();

        let slice_data: Vec<f32> = batch_data[start * row_size..end * row_size].to_vec();
        let mut new_dims = dims.to_vec();
        new_dims[0] = end - start;

        Tensor::from_slice(&slice_data, &new_dims).unwrap()
    }

    /// Get current pipeline stage
    pub fn current_stage(&self) -> usize {
        self.current_stage
    }

    /// Advance to next stage
    pub fn next_stage(&mut self) {
        self.current_stage = (self.current_stage + 1) % self.num_stages;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_data_parallel_split_batch() {
        let config = DistributedConfig {
            world_size: 2,
            rank: 0,
            ..Default::default()
        };
        let dp = DataParallel::new(config, vec![0, 1]);

        let batch = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2]).unwrap();
        let splits = dp.split_batch(&batch);

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].shape().dims()[0], 2);
        assert_eq!(splits[1].shape().dims()[0], 2);
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut accumulator = GradientAccumulator::new(4);

        let grad1 = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let grad2 = Tensor::from_slice(&[2.0f32, 3.0, 4.0], &[3]).unwrap();

        let mut grads = HashMap::new();
        grads.insert("layer1".to_string(), grad1);
        grads.insert("layer2".to_string(), grad2);

        accumulator.accumulate(&grads);
        assert!(!accumulator.should_update());

        accumulator.accumulate(&grads);
        accumulator.accumulate(&grads);
        accumulator.accumulate(&grads);
        assert!(accumulator.should_update());

        let final_grads = accumulator.get_and_reset();
        assert!(final_grads.contains_key("layer1"));
        assert_eq!(accumulator.current_step, 0);
    }

    #[test]
    fn test_model_parallel_placement() {
        let config = DistributedConfig::default();
        let mut mp = ModelParallel::new(config);

        mp.place_layer("layer1", 0);
        mp.place_layer("layer2", 1);
        mp.place_layer("layer3", 0);

        assert_eq!(mp.get_device("layer1"), Some(0));
        assert_eq!(mp.get_device("layer2"), Some(1));
        assert_eq!(mp.get_device("layer3"), Some(0));
    }

    #[test]
    fn test_auto_layer_placement() {
        let config = DistributedConfig::default();
        let mut mp = ModelParallel::new(config);

        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
            "layer4".to_string(),
        ];

        mp.auto_place_layers(&layers, 2);

        assert_eq!(mp.get_device("layer1"), Some(0));
        assert_eq!(mp.get_device("layer2"), Some(0));
        assert_eq!(mp.get_device("layer3"), Some(1));
        assert_eq!(mp.get_device("layer4"), Some(1));
    }

    #[test]
    fn test_ddp_forward_backward() {
        let config = DistributedConfig {
            world_size: 2,
            rank: 0,
            ..Default::default()
        };
        let mut ddp = DistributedDataParallel::new(config, vec![0, 1], Some(2));

        let batch = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let splits = ddp.forward(&batch);
        assert_eq!(splits.len(), 2);

        let mut gradients = HashMap::new();
        gradients.insert("layer1".to_string(), Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap());

        // First backward - should accumulate
        let result = ddp.backward(&gradients);
        assert!(result.is_none());

        // Second backward - should return accumulated gradients
        let result = ddp.backward(&gradients);
        assert!(result.is_some());
    }

    #[test]
    fn test_pipeline_parallel() {
        let pp = PipelineParallel::new(4, 8);

        let batch = Tensor::from_slice(&(0..32).map(|x| x as f32).collect::<Vec<_>>(), &[8, 4]).unwrap();
        let micro_batches = pp.create_micro_batches(&batch);

        assert_eq!(micro_batches.len(), 8);
        assert_eq!(micro_batches[0].shape().dims()[0], 1);
    }

    #[test]
    fn test_all_reduce_gradients() {
        let config = DistributedConfig {
            world_size: 4,
            rank: 0,
            ..Default::default()
        };
        let dp = DataParallel::new(config, vec![0, 1, 2, 3]);

        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1".to_string(),
            Tensor::from_slice(&[4.0f32, 8.0, 12.0], &[3]).unwrap()
        );

        let reduced = dp.all_reduce_gradients(&gradients);
        let grad_data = reduced.get("layer1").unwrap().storage().as_slice::<f32>();

        // Should be averaged across 4 devices
        assert!((grad_data[0] - 1.0).abs() < 0.01);
        assert!((grad_data[1] - 2.0).abs() < 0.01);
        assert!((grad_data[2] - 3.0).abs() < 0.01);
    }
}