// ghostflow_ml/distributed.rs

1//! Distributed Training - Data Parallelism and Model Parallelism
2//!
3//! This module provides utilities for distributed training across multiple nodes/processes.
4
5use std::sync::{Arc, Mutex};
6use ghostflow_core::Tensor;
7
/// Distributed training strategy: how the training workload is
/// partitioned across the workers in the group.
#[derive(Clone, Copy, Debug)]
pub enum DistributedStrategy {
    /// Data parallelism - every worker holds the full model and trains
    /// on a disjoint shard of the data.
    DataParallel,
    /// Model parallelism - the model itself is split across workers.
    ModelParallel,
    /// Hybrid approach combining data and model parallelism.
    Hybrid,
}
18
/// Communication backend used to exchange gradients/parameters
/// between workers during distributed training.
#[derive(Clone, Copy, Debug)]
pub enum CommunicationBackend {
    /// Thread-based (single machine, shared address space)
    Threads,
    /// Process-based (single or multiple machines, IPC)
    Processes,
    /// MPI-based (multiple machines, message passing)
    MPI,
}
29
/// Gradient aggregation method: how per-worker gradients are combined
/// into the globally applied gradient.
#[derive(Clone, Copy, Debug)]
pub enum GradientAggregation {
    /// Average gradients across workers (typical for data parallelism)
    Average,
    /// Sum gradients across workers
    Sum,
    /// Weighted average
    WeightedAverage,
}
40
41/// Distributed trainer configuration
42pub struct DistributedConfig {
43    pub strategy: DistributedStrategy,
44    pub backend: CommunicationBackend,
45    pub world_size: usize,
46    pub rank: usize,
47    pub gradient_aggregation: GradientAggregation,
48    pub sync_frequency: usize,
49}
50
51impl DistributedConfig {
52    pub fn new(world_size: usize, rank: usize) -> Self {
53        DistributedConfig {
54            strategy: DistributedStrategy::DataParallel,
55            backend: CommunicationBackend::Threads,
56            world_size,
57            rank,
58            gradient_aggregation: GradientAggregation::Average,
59            sync_frequency: 1,
60        }
61    }
62
63    pub fn strategy(mut self, strategy: DistributedStrategy) -> Self {
64        self.strategy = strategy;
65        self
66    }
67
68    pub fn backend(mut self, backend: CommunicationBackend) -> Self {
69        self.backend = backend;
70        self
71    }
72
73    pub fn gradient_aggregation(mut self, agg: GradientAggregation) -> Self {
74        self.gradient_aggregation = agg;
75        self
76    }
77
78    pub fn sync_frequency(mut self, freq: usize) -> Self {
79        self.sync_frequency = freq;
80        self
81    }
82}
83
/// Data parallel trainer
///
/// Coordinates gradient exchange for one worker in a data-parallel
/// group. Gradients are stored per layer as flat `Vec<f32>` vectors.
pub struct DataParallelTrainer {
    config: DistributedConfig,
    // Gradients computed by this worker for the current step.
    local_gradients: Arc<Mutex<Vec<Vec<f32>>>>,
    // Running aggregate across workers (used by the thread backend).
    global_gradients: Arc<Mutex<Vec<Vec<f32>>>>,
    // Count of `sync_gradients` calls; compared against `sync_frequency`.
    iteration: usize,
}

impl DataParallelTrainer {
    /// Create a trainer with empty local/global gradient buffers.
    pub fn new(config: DistributedConfig) -> Self {
        DataParallelTrainer {
            config,
            local_gradients: Arc::new(Mutex::new(Vec::new())),
            global_gradients: Arc::new(Mutex::new(Vec::new())),
            iteration: 0,
        }
    }

    /// Split data across workers
    ///
    /// Returns this rank's contiguous shard of `data`/`labels`. The last
    /// rank absorbs any remainder rows when `n_samples` is not divisible
    /// by `world_size`.
    ///
    /// NOTE(review): assumes `data` is 2-D `[n_samples, n_features]` and
    /// `labels` is 1-D `[n_samples]`, and that `world_size > 0` — confirm
    /// these invariants at the call sites. Panics via `unwrap` if tensor
    /// construction fails.
    pub fn split_data(&self, data: &Tensor, labels: &Tensor) -> (Tensor, Tensor) {
        let n_samples = data.dims()[0];
        let n_features = data.dims()[1];
        let samples_per_worker = n_samples / self.config.world_size;
        
        let start_idx = self.config.rank * samples_per_worker;
        // Last rank takes everything up to n_samples (remainder rows).
        let end_idx = if self.config.rank == self.config.world_size - 1 {
            n_samples
        } else {
            (self.config.rank + 1) * samples_per_worker
        };

        let data_slice = &data.data_f32()[start_idx * n_features..end_idx * n_features];
        let labels_slice = &labels.data_f32()[start_idx..end_idx];

        let local_data = Tensor::from_slice(data_slice, &[end_idx - start_idx, n_features]).unwrap();
        let local_labels = Tensor::from_slice(labels_slice, &[end_idx - start_idx]).unwrap();

        (local_data, local_labels)
    }

    /// Accumulate local gradients
    ///
    /// Despite the name, this REPLACES the stored local gradients with
    /// `gradients`; it does not add to them.
    pub fn accumulate_gradients(&mut self, gradients: Vec<Vec<f32>>) {
        let mut local_grads = self.local_gradients.lock().unwrap();
        *local_grads = gradients;
    }

    /// Synchronize gradients across workers
    ///
    /// Increments the iteration counter and, on every `sync_frequency`-th
    /// call, dispatches to the configured backend. Off-cycle calls simply
    /// return a copy of the local gradients.
    pub fn sync_gradients(&mut self) -> Vec<Vec<f32>> {
        self.iteration += 1;

        // Only sync at specified frequency
        if self.iteration % self.config.sync_frequency != 0 {
            return self.local_gradients.lock().unwrap().clone();
        }

        match self.config.backend {
            CommunicationBackend::Threads => self.sync_gradients_threads(),
            CommunicationBackend::Processes => self.sync_gradients_processes(),
            CommunicationBackend::MPI => self.sync_gradients_mpi(),
        }
    }

    // Thread-backend sync: folds this worker's gradients into the shared
    // global buffer and returns the aggregate.
    fn sync_gradients_threads(&self) -> Vec<Vec<f32>> {
        // Simplified thread-based synchronization
        let local_grads = self.local_gradients.lock().unwrap();
        let mut global_grads = self.global_gradients.lock().unwrap();

        if global_grads.is_empty() {
            // First contribution seeds the global buffer.
            *global_grads = local_grads.clone();
        } else {
            // Aggregate gradients
            for (global_layer, local_layer) in global_grads.iter_mut().zip(local_grads.iter()) {
                for (g, l) in global_layer.iter_mut().zip(local_layer.iter()) {
                    match self.config.gradient_aggregation {
                        GradientAggregation::Average => {
                            // NOTE(review): running-average formula treats the
                            // stored value as the mean of (world_size - 1)
                            // contributions — a simplification; verify against
                            // the intended aggregation semantics.
                            *g = (*g * (self.config.world_size - 1) as f32 + l) / self.config.world_size as f32;
                        }
                        GradientAggregation::Sum => {
                            *g += l;
                        }
                        GradientAggregation::WeightedAverage => {
                            // Equal-weight pairwise average (placeholder).
                            *g = (*g + l) / 2.0;
                        }
                    }
                }
            }
        }

        global_grads.clone()
    }

    fn sync_gradients_processes(&self) -> Vec<Vec<f32>> {
        // Placeholder for process-based synchronization
        // In a real implementation, this would use IPC mechanisms
        self.local_gradients.lock().unwrap().clone()
    }

    fn sync_gradients_mpi(&self) -> Vec<Vec<f32>> {
        // Placeholder for MPI-based synchronization
        // In a real implementation, this would use MPI libraries
        self.local_gradients.lock().unwrap().clone()
    }

    /// All-reduce operation for gradients
    ///
    /// Rescales `gradients` according to the configured aggregation mode.
    /// Assumes the input already contains the element-wise SUM over
    /// workers; Average/WeightedAverage divide by `world_size`, Sum is a
    /// pass-through.
    pub fn all_reduce(&self, gradients: &[Vec<f32>]) -> Vec<Vec<f32>> {
        // Simplified all-reduce
        let mut reduced = gradients.to_vec();

        match self.config.gradient_aggregation {
            GradientAggregation::Average => {
                for layer in &mut reduced {
                    for grad in layer {
                        *grad /= self.config.world_size as f32;
                    }
                }
            }
            GradientAggregation::Sum => {
                // Already summed
            }
            GradientAggregation::WeightedAverage => {
                for layer in &mut reduced {
                    for grad in layer {
                        *grad /= self.config.world_size as f32;
                    }
                }
            }
        }

        reduced
    }

    /// Broadcast parameters from rank 0 to all workers
    ///
    /// Placeholder: both branches currently return a copy of the input;
    /// no actual communication occurs.
    pub fn broadcast_parameters(&self, parameters: &[Vec<f32>]) -> Vec<Vec<f32>> {
        if self.config.rank == 0 {
            // Master broadcasts
            parameters.to_vec()
        } else {
            // Workers receive
            // In a real implementation, this would receive from master
            parameters.to_vec()
        }
    }

    /// True when this worker is rank 0 (the coordinating worker).
    pub fn is_master(&self) -> bool {
        self.config.rank == 0
    }
}
231
232/// Distributed data loader
233pub struct DistributedDataLoader {
234    pub batch_size: usize,
235    pub world_size: usize,
236    pub rank: usize,
237    pub shuffle: bool,
238    pub drop_last: bool,
239}
240
241impl DistributedDataLoader {
242    pub fn new(batch_size: usize, world_size: usize, rank: usize) -> Self {
243        DistributedDataLoader {
244            batch_size,
245            world_size,
246            rank,
247            shuffle: true,
248            drop_last: false,
249        }
250    }
251
252    pub fn shuffle(mut self, shuffle: bool) -> Self {
253        self.shuffle = shuffle;
254        self
255    }
256
257    pub fn drop_last(mut self, drop: bool) -> Self {
258        self.drop_last = drop;
259        self
260    }
261
262    /// Get batches for this worker
263    pub fn get_batches(&self, data: &Tensor, labels: &Tensor) -> Vec<(Tensor, Tensor)> {
264        let n_samples = data.dims()[0];
265        let n_features = data.dims()[1];
266        
267        // Calculate samples per worker
268        let samples_per_worker = n_samples / self.world_size;
269        let start_idx = self.rank * samples_per_worker;
270        let end_idx = if self.rank == self.world_size - 1 {
271            n_samples
272        } else {
273            (self.rank + 1) * samples_per_worker
274        };
275
276        let worker_samples = end_idx - start_idx;
277        let n_batches = if self.drop_last {
278            worker_samples / self.batch_size
279        } else {
280            (worker_samples + self.batch_size - 1) / self.batch_size
281        };
282
283        let mut batches = Vec::new();
284        let data_slice = data.data_f32();
285        let labels_slice = labels.data_f32();
286
287        for batch_idx in 0..n_batches {
288            let batch_start = start_idx + batch_idx * self.batch_size;
289            let batch_end = (batch_start + self.batch_size).min(end_idx);
290            let batch_size = batch_end - batch_start;
291
292            let batch_data: Vec<f32> = (batch_start..batch_end)
293                .flat_map(|i| data_slice[i * n_features..(i + 1) * n_features].to_vec())
294                .collect();
295            let batch_labels: Vec<f32> = (batch_start..batch_end)
296                .map(|i| labels_slice[i])
297                .collect();
298
299            let data_tensor = Tensor::from_slice(&batch_data, &[batch_size, n_features]).unwrap();
300            let labels_tensor = Tensor::from_slice(&batch_labels, &[batch_size]).unwrap();
301
302            batches.push((data_tensor, labels_tensor));
303        }
304
305        batches
306    }
307}
308
/// Gradient compression for communication efficiency
///
/// Compresses a flat gradient vector into a sparse (indices, values)
/// pair before it is exchanged between workers.
pub struct GradientCompression {
    pub method: CompressionMethod,
    /// Fraction of entries kept by the sparsifying methods (0, 1].
    pub compression_ratio: f32,
}

#[derive(Clone, Copy, Debug)]
pub enum CompressionMethod {
    /// No compression
    None,
    /// Top-K sparsification (keep largest-magnitude entries)
    TopK,
    /// Random sparsification (keep a uniform random subset)
    Random,
    /// Quantization (snap values to a coarse grid)
    Quantization,
}

impl GradientCompression {
    /// Create a compressor with the given method and a default
    /// compression ratio of 0.1 (keep 10% of entries).
    pub fn new(method: CompressionMethod) -> Self {
        GradientCompression {
            method,
            compression_ratio: 0.1,
        }
    }

    /// Set the fraction of entries kept by TopK/Random (builder-style).
    pub fn compression_ratio(mut self, ratio: f32) -> Self {
        self.compression_ratio = ratio;
        self
    }

    /// Compress `gradients` into parallel (indices, values) vectors.
    pub fn compress(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
        match self.method {
            CompressionMethod::None => {
                let indices: Vec<usize> = (0..gradients.len()).collect();
                (indices, gradients.to_vec())
            }
            CompressionMethod::TopK => self.compress_topk(gradients),
            CompressionMethod::Random => self.compress_random(gradients),
            CompressionMethod::Quantization => self.compress_quantize(gradients),
        }
    }

    // Keep the k entries with the largest magnitude.
    fn compress_topk(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
        let k = (gradients.len() as f32 * self.compression_ratio) as usize;
        let mut indexed: Vec<(usize, f32)> = gradients
            .iter()
            .enumerate()
            .map(|(i, &g)| (i, g.abs()))
            .collect();

        // total_cmp is NaN-safe; partial_cmp().unwrap() would panic the
        // training loop on a single NaN gradient.
        indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
        indexed.truncate(k);

        let indices: Vec<usize> = indexed.iter().map(|&(i, _)| i).collect();
        let values: Vec<f32> = indices.iter().map(|&i| gradients[i]).collect();

        (indices, values)
    }

    // Keep k entries chosen uniformly at random. Uses a Fisher-Yates
    // shuffle driven by an xorshift generator seeded from the standard
    // library's RandomState, avoiding a third-party RNG dependency.
    fn compress_random(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};

        let k = (gradients.len() as f32 * self.compression_ratio) as usize;
        let mut indices: Vec<usize> = (0..gradients.len()).collect();

        // Seed from the process-random hasher state; `| 1` keeps the
        // xorshift state nonzero (zero is a fixed point).
        let mut state = RandomState::new().build_hasher().finish() | 1;
        for i in (1..indices.len()).rev() {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let j = (state % (i as u64 + 1)) as usize;
            indices.swap(i, j);
        }
        indices.truncate(k);

        let values: Vec<f32> = indices.iter().map(|&i| gradients[i]).collect();

        (indices, values)
    }

    // Simple 8-bit-style quantization: snap each value to the nearest
    // multiple of max_abs / 127.
    fn compress_quantize(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
        let max_abs = gradients.iter().map(|g| g.abs()).fold(0.0f32, f32::max);
        let scale = max_abs / 127.0;

        // All-zero input gives scale == 0; dividing by it would turn every
        // entry into NaN, so pass the (all-zero) gradients through instead.
        let quantized: Vec<f32> = if scale == 0.0 {
            gradients.to_vec()
        } else {
            gradients
                .iter()
                .map(|&g| (g / scale).round() * scale)
                .collect()
        };

        let indices: Vec<usize> = (0..gradients.len()).collect();
        (indices, quantized)
    }

    /// Expand (indices, values) back into a dense vector of length `size`,
    /// with zeros at positions that were not transmitted. Out-of-range
    /// indices are ignored.
    pub fn decompress(&self, indices: &[usize], values: &[f32], size: usize) -> Vec<f32> {
        let mut decompressed = vec![0.0f32; size];
        for (&idx, &val) in indices.iter().zip(values.iter()) {
            if idx < size {
                decompressed[idx] = val;
            }
        }
        decompressed
    }
}
405
/// Ring All-Reduce implementation
///
/// Placeholder for the classic ring-topology reduction; a real version
/// would circulate gradient chunks between neighboring ranks.
pub struct RingAllReduce {
    pub world_size: usize,
    pub rank: usize,
}

impl RingAllReduce {
    /// Create a ring participant for the given group size and rank.
    pub fn new(world_size: usize, rank: usize) -> Self {
        RingAllReduce { world_size, rank }
    }

    /// Perform ring all-reduce on gradients
    ///
    /// Simulated: every entry of a layer is replaced by that layer's sum
    /// divided by `world_size`; no actual communication takes place.
    pub fn all_reduce(&self, gradients: &[Vec<f32>]) -> Vec<Vec<f32>> {
        // Simulate ring communication by broadcasting each layer's mean.
        gradients
            .iter()
            .map(|layer| {
                let mean = layer.iter().sum::<f32>() / self.world_size as f32;
                vec![mean; layer.len()]
            })
            .collect()
    }

    // Rank of the neighbor this worker would send to in the ring.
    #[allow(dead_code)]
    fn get_next_rank(&self) -> usize {
        (self.rank + 1) % self.world_size
    }

    // Rank of the neighbor this worker would receive from in the ring.
    #[allow(dead_code)]
    fn get_prev_rank(&self) -> usize {
        (self.rank + self.world_size - 1) % self.world_size
    }
}
/// Distributed optimizer wrapper
///
/// Wraps a local optimizer and routes gradients through a
/// `DataParallelTrainer`, optionally compressing them, before applying
/// parameter updates.
pub struct DistributedOptimizer<O> {
    // Wrapped optimizer; currently unused by `step` (see note there).
    #[allow(dead_code)]
    optimizer: O,
    trainer: DataParallelTrainer,
    compression: Option<GradientCompression>,
}

impl<O> DistributedOptimizer<O> {
    /// Wrap `optimizer` with distributed gradient synchronization
    /// configured by `config`. Compression is off by default.
    pub fn new(optimizer: O, config: DistributedConfig) -> Self {
        DistributedOptimizer {
            optimizer,
            trainer: DataParallelTrainer::new(config),
            compression: None,
        }
    }

    /// Enable gradient compression (builder-style).
    pub fn with_compression(mut self, compression: GradientCompression) -> Self {
        self.compression = Some(compression);
        self
    }

    /// Synchronize `local_grads` across workers, then apply one update
    /// step to `params`.
    ///
    /// NOTE(review): the wrapped optimizer is bypassed — parameters are
    /// updated with a hard-coded SGD step (lr = 0.01). A real
    /// implementation would delegate to `self.optimizer`.
    pub fn step(&mut self, params: &mut [f32], local_grads: &[f32]) {
        // Compress gradients if enabled
        // The compress/decompress round trip is lossy and simulates what
        // remote peers would actually receive.
        let grads_to_sync = if let Some(ref compression) = self.compression {
            let (indices, values) = compression.compress(local_grads);
            compression.decompress(&indices, &values, local_grads.len())
        } else {
            local_grads.to_vec()
        };

        // Synchronize gradients
        let grad_vec = vec![grads_to_sync];
        self.trainer.accumulate_gradients(grad_vec);
        let synced_grads = self.trainer.sync_gradients();

        // Apply optimizer step with synchronized gradients
        if !synced_grads.is_empty() && !synced_grads[0].is_empty() {
            // In a real implementation, this would call the optimizer's step method
            for (p, g) in params.iter_mut().zip(synced_grads[0].iter()) {
                *p -= 0.01 * g; // Simplified update
            }
        }
    }

    /// True when this worker is rank 0.
    pub fn is_master(&self) -> bool {
        self.trainer.is_master()
    }
}
494
#[cfg(test)]
mod tests {
    use super::*;

    // Rank 0 of a 2-worker group should receive exactly half of 10 samples.
    #[test]
    fn test_data_parallel_trainer() {
        let config = DistributedConfig::new(2, 0);
        let trainer = DataParallelTrainer::new(config);

        let data = Tensor::from_slice(&vec![1.0f32; 100], &[10, 10]).unwrap();
        let labels = Tensor::from_slice(&vec![0.0f32; 10], &[10]).unwrap();

        let (local_data, _local_labels) = trainer.split_data(&data, &labels);
        assert_eq!(local_data.dims()[0], 5); // Half the data
    }

    // Smoke test: the loader produces at least one batch from its shard.
    #[test]
    fn test_distributed_data_loader() {
        let loader = DistributedDataLoader::new(2, 2, 0);
        
        let data = Tensor::from_slice(&vec![1.0f32; 100], &[10, 10]).unwrap();
        let labels = Tensor::from_slice(&vec![0.0f32; 10], &[10]).unwrap();

        let batches = loader.get_batches(&data, &labels);
        assert!(batches.len() > 0);
    }

    // Top-K at ratio 0.5 keeps at most half the entries (plus rounding),
    // and decompression restores the original length.
    #[test]
    fn test_gradient_compression() {
        let compression = GradientCompression::new(CompressionMethod::TopK)
            .compression_ratio(0.5);

        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (indices, values) = compression.compress(&gradients);

        assert!(indices.len() <= (gradients.len() as f32 * 0.5) as usize + 1);
        
        let decompressed = compression.decompress(&indices, &values, gradients.len());
        assert_eq!(decompressed.len(), gradients.len());
    }

    // Shape-only check of the simulated ring all-reduce.
    #[test]
    fn test_ring_all_reduce() {
        let ring = RingAllReduce::new(4, 0);
        let gradients = vec![vec![1.0, 2.0, 3.0]];
        
        let result = ring.all_reduce(&gradients);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 3);
    }
545}