use ghostflow_core::tensor::Tensor;
use std::sync::{Arc, Mutex};
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DistributedBackend {
    /// NVIDIA Collective Communications Library (GPU collectives).
    NCCL,
    /// Gloo collective communications library (CPU-friendly).
    Gloo,
    /// Message Passing Interface.
    MPI,
}

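/// Process-group settings shared by the parallelism helpers in this module.
///
/// A minimal construction sketch using only the fields defined below (the
/// two-process values are illustrative, not defaults):
///
/// ```ignore
/// let config = DistributedConfig {
///     backend: DistributedBackend::NCCL,
///     world_size: 2,
///     rank: 0,
///     ..Default::default()
/// };
/// ```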
#[derive(Clone, Debug)]
pub struct DistributedConfig {
    /// Collective-communication backend to use.
    pub backend: DistributedBackend,
    /// Total number of processes in the group.
    pub world_size: usize,
    /// Rank of this process within the group (0-based).
    pub rank: usize,
    /// Address of the rendezvous (rank 0) host.
    pub master_addr: String,
    /// Port used for rendezvous.
    pub master_port: u16,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            backend: DistributedBackend::NCCL,
            world_size: 1,
            rank: 0,
            master_addr: "localhost".to_string(),
            master_port: 29500,
        }
    }
}

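/// Single-process data parallelism: splits a batch across the configured
/// devices and averages gradients as a stand-in for an all-reduce.
///
/// A minimal usage sketch, reusing the `Tensor::from_slice` constructor that
/// this module's tests rely on:
///
/// ```ignore
/// let config = DistributedConfig { world_size: 2, ..Default::default() };
/// let dp = DataParallel::new(config, vec![0, 1]);
///
/// let batch = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let per_device = dp.split_batch(&batch); // one slice per device id
/// assert_eq!(per_device.len(), 2);
/// ```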
pub struct DataParallel {
    config: DistributedConfig,
    device_ids: Vec<usize>,
    gradient_buckets: Arc<Mutex<HashMap<String, Vec<Tensor>>>>,
}

impl DataParallel {
    pub fn new(config: DistributedConfig, device_ids: Vec<usize>) -> Self {
        Self {
            config,
            device_ids,
            gradient_buckets: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    pub fn split_batch(&self, batch: &Tensor) -> Vec<Tensor> {
        let batch_size = batch.shape().dims()[0];
        let per_device = batch_size / self.device_ids.len();

        let mut splits = Vec::new();
        for i in 0..self.device_ids.len() {
            let start = i * per_device;
            // The last device picks up any remainder rows.
            let end = if i == self.device_ids.len() - 1 {
                batch_size
            } else {
                (i + 1) * per_device
            };

            let split = self.slice_batch(batch, start, end);
            splits.push(split);
        }

        splits
    }

    fn slice_batch(&self, batch: &Tensor, start: usize, end: usize) -> Tensor {
        let batch_data = batch.storage().as_slice::<f32>();
        let dims = batch.shape().dims();
        // Elements per example: product of all non-batch dimensions.
        let row_size = dims[1..].iter().product::<usize>();

        let slice_data: Vec<f32> = batch_data[start * row_size..end * row_size].to_vec();
        let mut new_dims = dims.to_vec();
        new_dims[0] = end - start;

        Tensor::from_slice(&slice_data, &new_dims).unwrap()
    }

    pub fn all_reduce_gradients(&self, gradients: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        // In a real multi-process run this would invoke the backend's all-reduce;
        // here the sum-then-average is simulated by scaling each gradient by
        // 1 / world_size.
        let mut averaged = HashMap::new();

        for (name, grad) in gradients {
            let averaged_grad = self.average_gradient(grad);
            averaged.insert(name.clone(), averaged_grad);
        }

        averaged
    }

    fn average_gradient(&self, grad: &Tensor) -> Tensor {
        let scale = 1.0 / self.config.world_size as f32;

        let grad_data = grad.storage().as_slice::<f32>();
        let scaled_data: Vec<f32> = grad_data.iter().map(|&x| x * scale).collect();

        Tensor::from_slice(&scaled_data, grad.shape().dims()).unwrap()
    }

    pub fn broadcast_parameters(&self, parameters: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        // In a real multi-process run, rank 0 would broadcast its parameters to
        // every other rank through the configured backend. In this single-process
        // simulation all ranks already hold identical values, so the parameters
        // are returned unchanged regardless of `self.config.rank`.
        parameters.clone()
    }
}

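/// Model parallelism bookkeeping: records which device each named layer lives
/// on, either explicitly or via a simple contiguous auto-placement.
///
/// A small sketch of the placement API defined below (the layer names are
/// illustrative):
///
/// ```ignore
/// let mut mp = ModelParallel::new(DistributedConfig::default());
/// mp.place_layer("embedding", 0);
/// mp.auto_place_layers(&["block1".to_string(), "block2".to_string()], 2);
/// assert_eq!(mp.get_device("embedding"), Some(0));
/// ```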
pub struct ModelParallel {
    config: DistributedConfig,
    layer_placement: HashMap<String, usize>,
}

impl ModelParallel {
    pub fn new(config: DistributedConfig) -> Self {
        Self {
            config,
            layer_placement: HashMap::new(),
        }
    }

    pub fn place_layer(&mut self, layer_name: &str, device_id: usize) {
        self.layer_placement.insert(layer_name.to_string(), device_id);
    }

    pub fn get_device(&self, layer_name: &str) -> Option<usize> {
        self.layer_placement.get(layer_name).copied()
    }

    pub fn auto_place_layers(&mut self, layer_names: &[String], num_devices: usize) {
        // Ceiling division: assign contiguous blocks of layers to each device.
        let layers_per_device = (layer_names.len() + num_devices - 1) / num_devices;

        for (i, name) in layer_names.iter().enumerate() {
            let device = i / layers_per_device;
            self.place_layer(name, device.min(num_devices - 1));
        }
    }

    pub fn transfer(&self, tensor: &Tensor, _from_device: usize, _to_device: usize) -> Tensor {
        // A real implementation would copy the tensor between devices; this
        // single-device simulation just returns a clone.
        tensor.clone()
    }
}

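/// Accumulates gradients over several micro-steps and releases their average
/// once `accumulation_steps` steps have been collected.
///
/// A minimal sketch of the accumulate/update cycle (the parameter name and
/// gradient values are illustrative):
///
/// ```ignore
/// let mut acc = GradientAccumulator::new(2);
/// let mut grads = HashMap::new();
/// grads.insert("w".to_string(), Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap());
///
/// acc.accumulate(&grads);
/// assert!(!acc.should_update());
/// acc.accumulate(&grads);
/// if acc.should_update() {
///     let averaged = acc.get_and_reset(); // each gradient scaled by 1/2
/// }
/// ```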
pub struct GradientAccumulator {
    accumulation_steps: usize,
    current_step: usize,
    accumulated_gradients: HashMap<String, Tensor>,
}

impl GradientAccumulator {
    pub fn new(accumulation_steps: usize) -> Self {
        Self {
            accumulation_steps,
            current_step: 0,
            accumulated_gradients: HashMap::new(),
        }
    }

    pub fn accumulate(&mut self, gradients: &HashMap<String, Tensor>) {
        for (name, grad) in gradients {
            // Add to the running sum for this parameter, or start a new one.
            let sum = match self.accumulated_gradients.get(name) {
                Some(accumulated) => self.add_tensors(accumulated, grad),
                None => grad.clone(),
            };
            self.accumulated_gradients.insert(name.clone(), sum);
        }

        self.current_step += 1;
    }

    fn add_tensors(&self, a: &Tensor, b: &Tensor) -> Tensor {
        let a_data = a.storage().as_slice::<f32>();
        let b_data = b.storage().as_slice::<f32>();

        let sum: Vec<f32> = a_data.iter().zip(b_data.iter())
            .map(|(x, y)| x + y)
            .collect();

        Tensor::from_slice(&sum, a.shape().dims()).unwrap()
    }

    pub fn should_update(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    pub fn get_and_reset(&mut self) -> HashMap<String, Tensor> {
        let gradients = self.accumulated_gradients.clone();
        self.accumulated_gradients.clear();
        self.current_step = 0;

        // Average the accumulated sums over the number of accumulation steps.
        let scale = 1.0 / self.accumulation_steps as f32;
        gradients.into_iter()
            .map(|(name, grad)| {
                let scaled = self.scale_tensor(&grad, scale);
                (name, scaled)
            })
            .collect()
    }

    fn scale_tensor(&self, tensor: &Tensor, scale: f32) -> Tensor {
        let data = tensor.storage().as_slice::<f32>();
        let scaled: Vec<f32> = data.iter().map(|&x| x * scale).collect();
        Tensor::from_slice(&scaled, tensor.shape().dims()).unwrap()
    }

    pub fn reset(&mut self) {
        self.accumulated_gradients.clear();
        self.current_step = 0;
    }
}

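/// Combines `DataParallel` batch splitting with optional gradient
/// accumulation; `backward` only yields gradients once an accumulation
/// window is complete.
///
/// A minimal sketch of one training step, mirroring this module's tests
/// (two devices, two accumulation steps):
///
/// ```ignore
/// let config = DistributedConfig { world_size: 2, ..Default::default() };
/// let mut ddp = DistributedDataParallel::new(config, vec![0, 1], Some(2));
///
/// let batch = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let shards = ddp.forward(&batch); // one shard per device
///
/// let mut grads = HashMap::new();
/// grads.insert("w".to_string(), Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap());
/// assert!(ddp.backward(&grads).is_none()); // still accumulating
/// assert!(ddp.backward(&grads).is_some()); // window complete
/// ```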
pub struct DistributedDataParallel {
    data_parallel: DataParallel,
    gradient_accumulator: Option<GradientAccumulator>,
    find_unused_parameters: bool,
}

impl DistributedDataParallel {
    pub fn new(
        config: DistributedConfig,
        device_ids: Vec<usize>,
        gradient_accumulation_steps: Option<usize>,
    ) -> Self {
        let gradient_accumulator = gradient_accumulation_steps
            .map(GradientAccumulator::new);

        Self {
            data_parallel: DataParallel::new(config, device_ids),
            gradient_accumulator,
            find_unused_parameters: false,
        }
    }

    pub fn find_unused_parameters(mut self, enabled: bool) -> Self {
        self.find_unused_parameters = enabled;
        self
    }

    pub fn forward(&self, batch: &Tensor) -> Vec<Tensor> {
        // Shard the input batch so each device gets its own slice.
        self.data_parallel.split_batch(batch)
    }

    pub fn backward(&mut self, gradients: &HashMap<String, Tensor>) -> Option<HashMap<String, Tensor>> {
        let reduced_gradients = self.data_parallel.all_reduce_gradients(gradients);

        if let Some(ref mut accumulator) = self.gradient_accumulator {
            accumulator.accumulate(&reduced_gradients);

            // Only release gradients once a full accumulation window is complete.
            if accumulator.should_update() {
                Some(accumulator.get_and_reset())
            } else {
                None
            }
        } else {
            Some(reduced_gradients)
        }
    }

    pub fn sync_parameters(&self, parameters: &HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        self.data_parallel.broadcast_parameters(parameters)
    }
}

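/// Pipeline parallelism bookkeeping: splits a batch into micro-batches and
/// tracks which pipeline stage is currently active.
///
/// A minimal sketch using the micro-batching API below (4 stages and 8
/// micro-batches, as in this module's tests):
///
/// ```ignore
/// let mut pp = PipelineParallel::new(4, 8);
/// let batch = Tensor::from_slice(&(0..32).map(|x| x as f32).collect::<Vec<_>>(), &[8, 4]).unwrap();
/// let micro_batches = pp.create_micro_batches(&batch);
/// assert_eq!(micro_batches.len(), 8);
/// pp.next_stage(); // advance to stage 1 of 4
/// ```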
pub struct PipelineParallel {
    num_stages: usize,
    num_micro_batches: usize,
    current_stage: usize,
}

impl PipelineParallel {
    pub fn new(num_stages: usize, num_micro_batches: usize) -> Self {
        Self {
            num_stages,
            num_micro_batches,
            current_stage: 0,
        }
    }

    pub fn create_micro_batches(&self, batch: &Tensor) -> Vec<Tensor> {
        let batch_size = batch.shape().dims()[0];
        let micro_batch_size = batch_size / self.num_micro_batches;

        let mut micro_batches = Vec::new();
        for i in 0..self.num_micro_batches {
            let start = i * micro_batch_size;
            // The last micro-batch picks up any remainder rows.
            let end = if i == self.num_micro_batches - 1 {
                batch_size
            } else {
                (i + 1) * micro_batch_size
            };

            let micro_batch = self.slice_batch(batch, start, end);
            micro_batches.push(micro_batch);
        }

        micro_batches
    }

    fn slice_batch(&self, batch: &Tensor, start: usize, end: usize) -> Tensor {
        let batch_data = batch.storage().as_slice::<f32>();
        let dims = batch.shape().dims();
        let row_size = dims[1..].iter().product::<usize>();

        let slice_data: Vec<f32> = batch_data[start * row_size..end * row_size].to_vec();
        let mut new_dims = dims.to_vec();
        new_dims[0] = end - start;

        Tensor::from_slice(&slice_data, &new_dims).unwrap()
    }

    pub fn current_stage(&self) -> usize {
        self.current_stage
    }

    pub fn next_stage(&mut self) {
        // Advance to the next stage, wrapping back to stage 0 after the last one.
        self.current_stage = (self.current_stage + 1) % self.num_stages;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_data_parallel_split_batch() {
        let config = DistributedConfig {
            world_size: 2,
            rank: 0,
            ..Default::default()
        };
        let dp = DataParallel::new(config, vec![0, 1]);

        let batch = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2]).unwrap();
        let splits = dp.split_batch(&batch);

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].shape().dims()[0], 2);
        assert_eq!(splits[1].shape().dims()[0], 2);
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut accumulator = GradientAccumulator::new(4);

        let grad1 = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let grad2 = Tensor::from_slice(&[2.0f32, 3.0, 4.0], &[3]).unwrap();

        let mut grads = HashMap::new();
        grads.insert("layer1".to_string(), grad1);
        grads.insert("layer2".to_string(), grad2);

        accumulator.accumulate(&grads);
        assert!(!accumulator.should_update());

        accumulator.accumulate(&grads);
        accumulator.accumulate(&grads);
        accumulator.accumulate(&grads);
        assert!(accumulator.should_update());

        let final_grads = accumulator.get_and_reset();
        assert!(final_grads.contains_key("layer1"));
        assert!(final_grads.contains_key("layer2"));
        assert_eq!(accumulator.current_step, 0);
    }

    #[test]
    fn test_model_parallel_placement() {
        let config = DistributedConfig::default();
        let mut mp = ModelParallel::new(config);

        mp.place_layer("layer1", 0);
        mp.place_layer("layer2", 1);
        mp.place_layer("layer3", 0);

        assert_eq!(mp.get_device("layer1"), Some(0));
        assert_eq!(mp.get_device("layer2"), Some(1));
        assert_eq!(mp.get_device("layer3"), Some(0));
    }

    #[test]
    fn test_auto_layer_placement() {
        let config = DistributedConfig::default();
        let mut mp = ModelParallel::new(config);

        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
            "layer4".to_string(),
        ];

        mp.auto_place_layers(&layers, 2);

        assert_eq!(mp.get_device("layer1"), Some(0));
        assert_eq!(mp.get_device("layer2"), Some(0));
        assert_eq!(mp.get_device("layer3"), Some(1));
        assert_eq!(mp.get_device("layer4"), Some(1));
    }

    #[test]
    fn test_ddp_forward_backward() {
        let config = DistributedConfig {
            world_size: 2,
            rank: 0,
            ..Default::default()
        };
        let mut ddp = DistributedDataParallel::new(config, vec![0, 1], Some(2));

        let batch = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let splits = ddp.forward(&batch);
        assert_eq!(splits.len(), 2);

        let mut gradients = HashMap::new();
        gradients.insert("layer1".to_string(), Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap());

        // First backward pass: still inside the accumulation window.
        let result = ddp.backward(&gradients);
        assert!(result.is_none());

        // Second backward pass: the window of 2 steps is complete.
        let result = ddp.backward(&gradients);
        assert!(result.is_some());
    }

    #[test]
    fn test_pipeline_parallel() {
        let pp = PipelineParallel::new(4, 8);

        let batch = Tensor::from_slice(&(0..32).map(|x| x as f32).collect::<Vec<_>>(), &[8, 4]).unwrap();
        let micro_batches = pp.create_micro_batches(&batch);

        assert_eq!(micro_batches.len(), 8);
        assert_eq!(micro_batches[0].shape().dims()[0], 1);
    }

    #[test]
    fn test_all_reduce_gradients() {
        let config = DistributedConfig {
            world_size: 4,
            rank: 0,
            ..Default::default()
        };
        let dp = DataParallel::new(config, vec![0, 1, 2, 3]);

        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1".to_string(),
            Tensor::from_slice(&[4.0f32, 8.0, 12.0], &[3]).unwrap()
        );

        let reduced = dp.all_reduce_gradients(&gradients);
        let grad_data = reduced.get("layer1").unwrap().storage().as_slice::<f32>();

        assert!((grad_data[0] - 1.0).abs() < 0.01);
        assert!((grad_data[1] - 2.0).abs() < 0.01);
        assert!((grad_data[2] - 3.0).abs() < 0.01);
    }
}