1use std::sync::{Arc, Mutex};
6use ghostflow_core::Tensor;
7
/// How training work is partitioned across workers.
#[derive(Clone, Copy, Debug)]
pub enum DistributedStrategy {
    /// Replicate the model on every worker and shard the data.
    DataParallel,
    /// Shard the model itself across workers.
    ModelParallel,
    /// Combination of data and model parallelism.
    Hybrid,
}
18
/// Transport used to exchange gradients between workers.
#[derive(Clone, Copy, Debug)]
pub enum CommunicationBackend {
    /// In-process threads sharing memory — the only backend with a real
    /// aggregation path in this module.
    Threads,
    /// Separate OS processes. Currently a stub: sync returns local gradients.
    Processes,
    /// MPI. Currently a stub: sync returns local gradients.
    MPI,
}
29
/// Rule for combining gradient contributions during synchronization.
#[derive(Clone, Copy, Debug)]
pub enum GradientAggregation {
    /// Fold contributions into a running mean weighted by `world_size`.
    Average,
    /// Add contributions element-wise.
    Sum,
    /// Equal-weight pairwise mean of the existing and incoming values.
    WeightedAverage,
}
40
/// One worker's view of a distributed training run.
pub struct DistributedConfig {
    /// How training is parallelized across workers.
    pub strategy: DistributedStrategy,
    /// Transport used for gradient exchange.
    pub backend: CommunicationBackend,
    /// Total number of participating workers.
    pub world_size: usize,
    /// This worker's index in `0..world_size` (rank 0 acts as master).
    pub rank: usize,
    /// Rule for combining gradients during sync.
    pub gradient_aggregation: GradientAggregation,
    /// Synchronize every `sync_frequency` iterations.
    pub sync_frequency: usize,
}
50
51impl DistributedConfig {
52 pub fn new(world_size: usize, rank: usize) -> Self {
53 DistributedConfig {
54 strategy: DistributedStrategy::DataParallel,
55 backend: CommunicationBackend::Threads,
56 world_size,
57 rank,
58 gradient_aggregation: GradientAggregation::Average,
59 sync_frequency: 1,
60 }
61 }
62
63 pub fn strategy(mut self, strategy: DistributedStrategy) -> Self {
64 self.strategy = strategy;
65 self
66 }
67
68 pub fn backend(mut self, backend: CommunicationBackend) -> Self {
69 self.backend = backend;
70 self
71 }
72
73 pub fn gradient_aggregation(mut self, agg: GradientAggregation) -> Self {
74 self.gradient_aggregation = agg;
75 self
76 }
77
78 pub fn sync_frequency(mut self, freq: usize) -> Self {
79 self.sync_frequency = freq;
80 self
81 }
82}
83
/// Data-parallel trainer: shards input data by rank and aggregates
/// gradients through shared, mutex-guarded buffers.
pub struct DataParallelTrainer {
    config: DistributedConfig,
    // This worker's most recently accumulated per-layer gradients.
    local_gradients: Arc<Mutex<Vec<Vec<f32>>>>,
    // Aggregated gradients shared across (simulated) workers.
    global_gradients: Arc<Mutex<Vec<Vec<f32>>>>,
    // Number of sync_gradients() calls so far; drives sync_frequency.
    iteration: usize,
}
91
92impl DataParallelTrainer {
93 pub fn new(config: DistributedConfig) -> Self {
94 DataParallelTrainer {
95 config,
96 local_gradients: Arc::new(Mutex::new(Vec::new())),
97 global_gradients: Arc::new(Mutex::new(Vec::new())),
98 iteration: 0,
99 }
100 }
101
102 pub fn split_data(&self, data: &Tensor, labels: &Tensor) -> (Tensor, Tensor) {
104 let n_samples = data.dims()[0];
105 let n_features = data.dims()[1];
106 let samples_per_worker = n_samples / self.config.world_size;
107
108 let start_idx = self.config.rank * samples_per_worker;
109 let end_idx = if self.config.rank == self.config.world_size - 1 {
110 n_samples
111 } else {
112 (self.config.rank + 1) * samples_per_worker
113 };
114
115 let data_slice = &data.data_f32()[start_idx * n_features..end_idx * n_features];
116 let labels_slice = &labels.data_f32()[start_idx..end_idx];
117
118 let local_data = Tensor::from_slice(data_slice, &[end_idx - start_idx, n_features]).unwrap();
119 let local_labels = Tensor::from_slice(labels_slice, &[end_idx - start_idx]).unwrap();
120
121 (local_data, local_labels)
122 }
123
124 pub fn accumulate_gradients(&mut self, gradients: Vec<Vec<f32>>) {
126 let mut local_grads = self.local_gradients.lock().unwrap();
127 *local_grads = gradients;
128 }
129
130 pub fn sync_gradients(&mut self) -> Vec<Vec<f32>> {
132 self.iteration += 1;
133
134 if self.iteration % self.config.sync_frequency != 0 {
136 return self.local_gradients.lock().unwrap().clone();
137 }
138
139 match self.config.backend {
140 CommunicationBackend::Threads => self.sync_gradients_threads(),
141 CommunicationBackend::Processes => self.sync_gradients_processes(),
142 CommunicationBackend::MPI => self.sync_gradients_mpi(),
143 }
144 }
145
146 fn sync_gradients_threads(&self) -> Vec<Vec<f32>> {
147 let local_grads = self.local_gradients.lock().unwrap();
149 let mut global_grads = self.global_gradients.lock().unwrap();
150
151 if global_grads.is_empty() {
152 *global_grads = local_grads.clone();
153 } else {
154 for (global_layer, local_layer) in global_grads.iter_mut().zip(local_grads.iter()) {
156 for (g, l) in global_layer.iter_mut().zip(local_layer.iter()) {
157 match self.config.gradient_aggregation {
158 GradientAggregation::Average => {
159 *g = (*g * (self.config.world_size - 1) as f32 + l) / self.config.world_size as f32;
160 }
161 GradientAggregation::Sum => {
162 *g += l;
163 }
164 GradientAggregation::WeightedAverage => {
165 *g = (*g + l) / 2.0;
166 }
167 }
168 }
169 }
170 }
171
172 global_grads.clone()
173 }
174
175 fn sync_gradients_processes(&self) -> Vec<Vec<f32>> {
176 self.local_gradients.lock().unwrap().clone()
179 }
180
181 fn sync_gradients_mpi(&self) -> Vec<Vec<f32>> {
182 self.local_gradients.lock().unwrap().clone()
185 }
186
187 pub fn all_reduce(&self, gradients: &[Vec<f32>]) -> Vec<Vec<f32>> {
189 let mut reduced = gradients.to_vec();
191
192 match self.config.gradient_aggregation {
193 GradientAggregation::Average => {
194 for layer in &mut reduced {
195 for grad in layer {
196 *grad /= self.config.world_size as f32;
197 }
198 }
199 }
200 GradientAggregation::Sum => {
201 }
203 GradientAggregation::WeightedAverage => {
204 for layer in &mut reduced {
205 for grad in layer {
206 *grad /= self.config.world_size as f32;
207 }
208 }
209 }
210 }
211
212 reduced
213 }
214
215 pub fn broadcast_parameters(&self, parameters: &[Vec<f32>]) -> Vec<Vec<f32>> {
217 if self.config.rank == 0 {
218 parameters.to_vec()
220 } else {
221 parameters.to_vec()
224 }
225 }
226
227 pub fn is_master(&self) -> bool {
228 self.config.rank == 0
229 }
230}
231
/// Loader that hands each rank its shard of a dataset as mini-batches.
pub struct DistributedDataLoader {
    /// Samples per mini-batch.
    pub batch_size: usize,
    /// Total number of workers.
    pub world_size: usize,
    /// This worker's index in `0..world_size`.
    pub rank: usize,
    /// Shuffle flag. NOTE(review): stored but not applied by `get_batches`.
    pub shuffle: bool,
    /// Drop the trailing partial batch when true.
    pub drop_last: bool,
}
240
241impl DistributedDataLoader {
242 pub fn new(batch_size: usize, world_size: usize, rank: usize) -> Self {
243 DistributedDataLoader {
244 batch_size,
245 world_size,
246 rank,
247 shuffle: true,
248 drop_last: false,
249 }
250 }
251
252 pub fn shuffle(mut self, shuffle: bool) -> Self {
253 self.shuffle = shuffle;
254 self
255 }
256
257 pub fn drop_last(mut self, drop: bool) -> Self {
258 self.drop_last = drop;
259 self
260 }
261
262 pub fn get_batches(&self, data: &Tensor, labels: &Tensor) -> Vec<(Tensor, Tensor)> {
264 let n_samples = data.dims()[0];
265 let n_features = data.dims()[1];
266
267 let samples_per_worker = n_samples / self.world_size;
269 let start_idx = self.rank * samples_per_worker;
270 let end_idx = if self.rank == self.world_size - 1 {
271 n_samples
272 } else {
273 (self.rank + 1) * samples_per_worker
274 };
275
276 let worker_samples = end_idx - start_idx;
277 let n_batches = if self.drop_last {
278 worker_samples / self.batch_size
279 } else {
280 (worker_samples + self.batch_size - 1) / self.batch_size
281 };
282
283 let mut batches = Vec::new();
284 let data_slice = data.data_f32();
285 let labels_slice = labels.data_f32();
286
287 for batch_idx in 0..n_batches {
288 let batch_start = start_idx + batch_idx * self.batch_size;
289 let batch_end = (batch_start + self.batch_size).min(end_idx);
290 let batch_size = batch_end - batch_start;
291
292 let batch_data: Vec<f32> = (batch_start..batch_end)
293 .flat_map(|i| data_slice[i * n_features..(i + 1) * n_features].to_vec())
294 .collect();
295 let batch_labels: Vec<f32> = (batch_start..batch_end)
296 .map(|i| labels_slice[i])
297 .collect();
298
299 let data_tensor = Tensor::from_slice(&batch_data, &[batch_size, n_features]).unwrap();
300 let labels_tensor = Tensor::from_slice(&batch_labels, &[batch_size]).unwrap();
301
302 batches.push((data_tensor, labels_tensor));
303 }
304
305 batches
306 }
307}
308
/// Lossy gradient compressor used to reduce synchronization bandwidth.
pub struct GradientCompression {
    /// Which compression scheme to apply.
    pub method: CompressionMethod,
    /// Fraction of gradient entries kept by TopK/Random (default 0.1).
    pub compression_ratio: f32,
}
314
/// Available gradient compression schemes.
#[derive(Clone, Copy, Debug)]
pub enum CompressionMethod {
    /// Pass gradients through unchanged.
    None,
    /// Keep the `compression_ratio` fraction with the largest magnitude.
    TopK,
    /// Keep a uniformly random `compression_ratio` fraction.
    Random,
    /// Snap values to a grid of step `max_abs / 127` (8-bit-style levels).
    Quantization,
}
326
327impl GradientCompression {
328 pub fn new(method: CompressionMethod) -> Self {
329 GradientCompression {
330 method,
331 compression_ratio: 0.1,
332 }
333 }
334
335 pub fn compression_ratio(mut self, ratio: f32) -> Self {
336 self.compression_ratio = ratio;
337 self
338 }
339
340 pub fn compress(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
341 match self.method {
342 CompressionMethod::None => {
343 let indices: Vec<usize> = (0..gradients.len()).collect();
344 (indices, gradients.to_vec())
345 }
346 CompressionMethod::TopK => self.compress_topk(gradients),
347 CompressionMethod::Random => self.compress_random(gradients),
348 CompressionMethod::Quantization => self.compress_quantize(gradients),
349 }
350 }
351
352 fn compress_topk(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
353 let k = (gradients.len() as f32 * self.compression_ratio) as usize;
354 let mut indexed: Vec<(usize, f32)> = gradients.iter()
355 .enumerate()
356 .map(|(i, &g)| (i, g.abs()))
357 .collect();
358
359 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
360 indexed.truncate(k);
361
362 let indices: Vec<usize> = indexed.iter().map(|(i, _)| *i).collect();
363 let values: Vec<f32> = indexed.iter().map(|(i, _)| gradients[*i]).collect();
364
365 (indices, values)
366 }
367
368 fn compress_random(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
369 use rand::prelude::*;
370 let mut rng = thread_rng();
371 let k = (gradients.len() as f32 * self.compression_ratio) as usize;
372
373 let mut indices: Vec<usize> = (0..gradients.len()).collect();
374 indices.shuffle(&mut rng);
375 indices.truncate(k);
376
377 let values: Vec<f32> = indices.iter().map(|&i| gradients[i]).collect();
378
379 (indices, values)
380 }
381
382 fn compress_quantize(&self, gradients: &[f32]) -> (Vec<usize>, Vec<f32>) {
383 let max_abs = gradients.iter().map(|g| g.abs()).fold(0.0f32, f32::max);
385 let scale = max_abs / 127.0;
386
387 let quantized: Vec<f32> = gradients.iter()
388 .map(|&g| (g / scale).round() * scale)
389 .collect();
390
391 let indices: Vec<usize> = (0..gradients.len()).collect();
392 (indices, quantized)
393 }
394
395 pub fn decompress(&self, indices: &[usize], values: &[f32], size: usize) -> Vec<f32> {
396 let mut decompressed = vec![0.0f32; size];
397 for (&idx, &val) in indices.iter().zip(values.iter()) {
398 if idx < size {
399 decompressed[idx] = val;
400 }
401 }
402 decompressed
403 }
404}
405
/// Single-process stand-in for a ring all-reduce collective.
pub struct RingAllReduce {
    /// Total number of participants in the ring.
    pub world_size: usize,
    /// This participant's position in the ring.
    pub rank: usize,
}

impl RingAllReduce {
    /// Create a ring over `world_size` participants for this `rank`.
    pub fn new(world_size: usize, rank: usize) -> Self {
        Self { world_size, rank }
    }

    /// Reduce each layer to a constant vector: every element becomes
    /// `sum(layer) / world_size`. Layer count and lengths are preserved.
    pub fn all_reduce(&self, gradients: &[Vec<f32>]) -> Vec<Vec<f32>> {
        gradients
            .iter()
            .map(|layer| {
                let avg = layer.iter().sum::<f32>() / self.world_size as f32;
                vec![avg; layer.len()]
            })
            .collect()
    }

    /// Rank that follows this one in the ring.
    #[allow(dead_code)]
    fn get_next_rank(&self) -> usize {
        (self.rank + 1) % self.world_size
    }

    /// Rank that precedes this one in the ring.
    #[allow(dead_code)]
    fn get_prev_rank(&self) -> usize {
        (self.rank + self.world_size - 1) % self.world_size
    }
}
/// Wraps an optimizer with distributed gradient sync and optional compression.
pub struct DistributedOptimizer<O> {
    // Underlying optimizer; currently unused by `step`, which applies a
    // plain SGD update itself.
    #[allow(dead_code)]
    optimizer: O,
    // Handles gradient accumulation and cross-worker synchronization.
    trainer: DataParallelTrainer,
    // When set, gradients are round-tripped through compression before sync.
    compression: Option<GradientCompression>,
}
452
453impl<O> DistributedOptimizer<O> {
454 pub fn new(optimizer: O, config: DistributedConfig) -> Self {
455 DistributedOptimizer {
456 optimizer,
457 trainer: DataParallelTrainer::new(config),
458 compression: None,
459 }
460 }
461
462 pub fn with_compression(mut self, compression: GradientCompression) -> Self {
463 self.compression = Some(compression);
464 self
465 }
466
467 pub fn step(&mut self, params: &mut [f32], local_grads: &[f32]) {
468 let grads_to_sync = if let Some(ref compression) = self.compression {
470 let (indices, values) = compression.compress(local_grads);
471 compression.decompress(&indices, &values, local_grads.len())
472 } else {
473 local_grads.to_vec()
474 };
475
476 let grad_vec = vec![grads_to_sync];
478 self.trainer.accumulate_gradients(grad_vec);
479 let synced_grads = self.trainer.sync_gradients();
480
481 if !synced_grads.is_empty() && !synced_grads[0].is_empty() {
483 for (p, g) in params.iter_mut().zip(synced_grads[0].iter()) {
485 *p -= 0.01 * g; }
487 }
488 }
489
490 pub fn is_master(&self) -> bool {
491 self.trainer.is_master()
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    // Rank 0 of a 2-worker run gets the first half (5 of 10 samples).
    #[test]
    fn test_data_parallel_trainer() {
        let config = DistributedConfig::new(2, 0);
        let trainer = DataParallelTrainer::new(config);

        let data = Tensor::from_slice(&vec![1.0f32; 100], &[10, 10]).unwrap();
        let labels = Tensor::from_slice(&vec![0.0f32; 10], &[10]).unwrap();

        let (local_data, _local_labels) = trainer.split_data(&data, &labels);
        assert_eq!(local_data.dims()[0], 5); }

    // A non-empty shard must yield at least one batch.
    #[test]
    fn test_distributed_data_loader() {
        let loader = DistributedDataLoader::new(2, 2, 0);

        let data = Tensor::from_slice(&vec![1.0f32; 100], &[10, 10]).unwrap();
        let labels = Tensor::from_slice(&vec![0.0f32; 10], &[10]).unwrap();

        let batches = loader.get_batches(&data, &labels);
        assert!(batches.len() > 0);
    }

    // TopK at ratio 0.5 keeps at most roughly half the entries, and
    // decompression restores a dense vector of the original length.
    #[test]
    fn test_gradient_compression() {
        let compression = GradientCompression::new(CompressionMethod::TopK)
            .compression_ratio(0.5);

        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (indices, values) = compression.compress(&gradients);

        assert!(indices.len() <= (gradients.len() as f32 * 0.5) as usize + 1);

        let decompressed = compression.decompress(&indices, &values, gradients.len());
        assert_eq!(decompressed.len(), gradients.len());
    }

    // Shape check only: all_reduce preserves layer count and layer length.
    #[test]
    fn test_ring_all_reduce() {
        let ring = RingAllReduce::new(4, 0);
        let gradients = vec![vec![1.0, 2.0, 3.0]];

        let result = ring.all_reduce(&gradients);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 3);
    }
}
546
547