1use scirs2_core::gpu::{GpuBuffer, GpuContext, GpuDataType, GpuKernelHandle};
4use scirs2_core::ndarray::{ArrayBase, Data, DataMut, Dimension};
5use scirs2_core::numeric::Float;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use crate::backends::GpuBackend;
10use crate::GpuOptimError;
11
/// Algorithm used to synchronize (all-reduce) gradients across GPUs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyncStrategy {
    /// Ring-based all-reduce: data chunks travel around a logical ring of GPUs.
    RingAllReduce,
    /// Tree-based all-reduce using XOR-peer (butterfly) exchanges.
    TreeAllReduce,
    /// Two-level all-reduce: reduce within local groups, then across group leaders.
    HierarchicalAllReduce,
    /// Chunked, pipelined communication; requires async parameter updates.
    PipelineParallel,
}
24
/// Configuration for multi-GPU gradient synchronization.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Total number of GPUs participating in the collective.
    pub num_gpus: usize,
    /// Rank (index) of this GPU within the group, in `0..num_gpus`.
    pub rank: usize,
    /// All-reduce strategy used when adaptive selection is disabled.
    pub sync_strategy: SyncStrategy,
    /// Enables top-k gradient compression kernels and buffers.
    pub gradient_compression: bool,
    /// Fraction of gradient elements kept when compression is enabled (0..1).
    pub compression_ratio: f32,
    /// GPUs per local group for hierarchical all-reduce.
    pub local_group_size: usize,
    /// Lets the runtime switch strategies based on measured performance.
    pub adaptive_communication: bool,
    /// Number of steps between performance-statistics log lines.
    pub bandwidth_monitor_interval: usize,
    /// Enables asynchronous parameter updates (required by `PipelineParallel`).
    pub async_param_updates: bool,
    /// Communication timeout in milliseconds.
    // NOTE(review): not read anywhere in this file — presumably consumed elsewhere; confirm.
    pub communication_timeout_ms: u64,
    /// Enables error-correction support for compressed communication.
    // NOTE(review): not read anywhere in this file — verify against callers.
    pub error_correction: bool,
    /// Number of pipeline stages for `SyncStrategy::PipelineParallel`.
    pub pipeline_depth: usize,
}
53
54impl Default for MultiGpuConfig {
55 fn default() -> Self {
56 Self {
57 num_gpus: 1,
58 rank: 0,
59 sync_strategy: SyncStrategy::RingAllReduce,
60 gradient_compression: false,
61 compression_ratio: 0.1, local_group_size: 4,
63 adaptive_communication: true,
64 bandwidth_monitor_interval: 100,
65 async_param_updates: false,
66 communication_timeout_ms: 5000,
67 error_correction: true,
68 pipeline_depth: 2,
69 }
70 }
71}
72
/// Per-rank manager for multi-GPU gradient synchronization.
pub struct MultiGpuSync<A: Float + GpuDataType> {
    /// GPU context used for buffer creation and kernel dispatch.
    context: Arc<GpuContext>,
    /// Static configuration (world size, rank, strategy, …).
    config: MultiGpuConfig,
    /// Kernel handles loaded for the configured strategies.
    sync_kernels: SyncKernels,
    /// Pre-allocated device-side scratch buffers.
    workspace: WorkspaceBuffers<A>,
    /// Bandwidth/latency bookkeeping for all communication operations.
    perf_monitor: CommunicationPerformanceMonitor,
    /// Adaptive strategy switcher driven by `perf_monitor`.
    adaptive_selector: AdaptiveCommunicationSelector,
    /// Handles for in-flight asynchronous (pipeline) operations.
    async_handles: Vec<AsyncCommunicationHandle>,
    /// Number of `sync_gradients` calls performed so far.
    step_counter: usize,
    /// Ties the element type `A` to the struct.
    // NOTE(review): `workspace` already uses `A`, so this marker looks redundant.
    _phantom: PhantomData<A>,
}
94
/// GPU kernel handles for the synchronization strategies.
///
/// Each handle is `Some` only when the corresponding strategy/feature was
/// requested in `MultiGpuConfig` at construction time (see `load_sync_kernels`).
struct SyncKernels {
    ring_allreduce: Option<Arc<GpuKernelHandle>>,
    tree_allreduce: Option<Arc<GpuKernelHandle>>,
    hierarchical_allreduce: Option<Arc<GpuKernelHandle>>,
    compress_gradients: Option<Arc<GpuKernelHandle>>,
    decompress_gradients: Option<Arc<GpuKernelHandle>>,
}
103
/// Device-side scratch buffers, allocated once in `allocate_workspace`.
///
/// All buffers are `None` when no GPU backend feature is enabled. The
/// compression-related buffers are `None` unless gradient compression is on.
struct WorkspaceBuffers<A: Float + GpuDataType> {
    /// Receive buffer for ring all-reduce exchanges.
    recv_buffer: Option<GpuBuffer<A>>,
    /// General scratch space for tree/hierarchical all-reduce.
    workspace: Option<GpuBuffer<A>>,
    /// Top-k compressed gradient values.
    compressed_values: Option<GpuBuffer<A>>,
    /// Indices matching `compressed_values`.
    compressed_indices: Option<GpuBuffer<i32>>,
    /// Residual buffer for error-feedback compression.
    error_feedback: Option<GpuBuffer<A>>,
}
112
/// Accumulates communication timing and bandwidth statistics.
#[derive(Debug, Clone)]
pub struct CommunicationPerformanceMonitor {
    /// Total time spent in communication, in microseconds.
    total_comm_time_us: u64,
    /// Total bytes transferred across all operations.
    total_data_bytes: u64,
    /// Number of recorded communication operations.
    comm_operations: usize,
    /// Sliding window of per-operation bandwidth samples (GB/s).
    bandwidth_history: std::collections::VecDeque<f64>,
    /// Per-strategy performance metrics used for adaptive selection.
    strategy_performance: std::collections::HashMap<SyncStrategy, StrategyPerformanceMetrics>,
    /// Last strategy considered optimal.
    // NOTE(review): written at construction only; `get_optimal_strategy` recomputes on demand.
    optimal_strategy: SyncStrategy,
}
129
130impl CommunicationPerformanceMonitor {
131 fn new() -> Self {
132 Self {
133 total_comm_time_us: 0,
134 total_data_bytes: 0,
135 comm_operations: 0,
136 bandwidth_history: std::collections::VecDeque::with_capacity(1000),
137 strategy_performance: std::collections::HashMap::new(),
138 optimal_strategy: SyncStrategy::RingAllReduce,
139 }
140 }
141
142 fn record_communication(&mut self, strategy: SyncStrategy, data_bytes: u64, timeus: u64) {
143 self.total_comm_time_us += timeus;
144 self.total_data_bytes += data_bytes;
145 self.comm_operations += 1;
146
147 let bandwidth_gb_s = (data_bytes as f64) / (timeus as f64 / 1_000_000.0) / 1e9;
148 self.bandwidth_history.push_back(bandwidth_gb_s);
149
150 if self.bandwidth_history.len() > 1000 {
151 self.bandwidth_history.pop_front();
152 }
153
154 let metrics = self
156 .strategy_performance
157 .entry(strategy)
158 .or_insert_with(StrategyPerformanceMetrics::new);
159 metrics.update(bandwidth_gb_s, timeus);
160 }
161
162 fn get_average_bandwidth(&self) -> f64 {
163 if self.total_comm_time_us == 0 {
164 0.0
165 } else {
166 (self.total_data_bytes as f64) / (self.total_comm_time_us as f64 / 1_000_000.0) / 1e9
167 }
168 }
169
170 fn get_optimal_strategy(&self, tensorsize: usize) -> SyncStrategy {
171 let mut best_strategy = SyncStrategy::RingAllReduce;
172 let mut best_score = 0.0;
173
174 for (strategy, metrics) in &self.strategy_performance {
175 let score = metrics.calculate_score(tensorsize);
176 if score > best_score {
177 best_score = score;
178 best_strategy = *strategy;
179 }
180 }
181
182 best_strategy
183 }
184}
185
/// Sliding-window performance metrics for a single sync strategy.
#[derive(Debug, Clone)]
struct StrategyPerformanceMetrics {
    /// Recent bandwidth samples (GB/s).
    bandwidth_samples: std::collections::VecDeque<f64>,
    /// Recent latency samples (µs), parallel to `bandwidth_samples`.
    latency_samples: std::collections::VecDeque<u64>,
    /// Recent tensor sizes.
    // NOTE(review): never pushed to anywhere in this file — likely vestigial.
    tensor_sizes: std::collections::VecDeque<usize>,
    /// Derived score: average bandwidth divided by average latency (ms).
    efficiency_score: f64,
}
194
195impl StrategyPerformanceMetrics {
196 fn new() -> Self {
197 Self {
198 bandwidth_samples: std::collections::VecDeque::with_capacity(100),
199 latency_samples: std::collections::VecDeque::with_capacity(100),
200 tensor_sizes: std::collections::VecDeque::with_capacity(100),
201 efficiency_score: 0.0,
202 }
203 }
204
205 fn update(&mut self, bandwidth_gb_s: f64, latencyus: u64) {
206 self.bandwidth_samples.push_back(bandwidth_gb_s);
207 self.latency_samples.push_back(latencyus);
208
209 if self.bandwidth_samples.len() > 100 {
210 self.bandwidth_samples.pop_front();
211 self.latency_samples.pop_front();
212 }
213
214 let avg_bandwidth =
216 self.bandwidth_samples.iter().sum::<f64>() / self.bandwidth_samples.len() as f64;
217 let avg_latency =
218 self.latency_samples.iter().sum::<u64>() as f64 / self.latency_samples.len() as f64;
219
220 self.efficiency_score = avg_bandwidth / (avg_latency / 1000.0); }
222
223 fn calculate_score(&self, tensorsize: usize) -> f64 {
224 let size_factor = if tensorsize > 1000000 { 2.0 } else { 1.0 }; self.efficiency_score * size_factor
227 }
228}
229
/// Chooses the active sync strategy based on monitored performance,
/// with a cooldown between switches to avoid thrashing.
#[derive(Debug)]
pub struct AdaptiveCommunicationSelector {
    /// Strategy currently in use.
    current_strategy: SyncStrategy,
    /// Minimum number of steps between strategy switches.
    switch_cooldown: usize,
    /// Step counter value at the most recent switch.
    last_switch_step: usize,
    /// Evaluation-window length in steps.
    // NOTE(review): not read anywhere in this file — confirm intended use.
    evaluation_window: usize,
    /// Minimum improvement ratio required to switch strategies.
    performance_threshold: f64,
}
244
245impl AdaptiveCommunicationSelector {
246 fn new() -> Self {
247 Self {
248 current_strategy: SyncStrategy::RingAllReduce,
249 switch_cooldown: 50,
250 last_switch_step: 0,
251 evaluation_window: 20,
252 performance_threshold: 1.2, }
254 }
255
256 fn should_evaluate_strategy(&self, currentstep: usize) -> bool {
257 currentstep - self.last_switch_step >= self.switch_cooldown
258 }
259
260 fn evaluate_and_switch(
261 &mut self,
262 monitor: &CommunicationPerformanceMonitor,
263 tensor_size: usize,
264 current_step: usize,
265 ) -> Option<SyncStrategy> {
266 if !self.should_evaluate_strategy(current_step) {
267 return None;
268 }
269
270 let optimal_strategy = monitor.get_optimal_strategy(tensor_size);
271
272 if optimal_strategy != self.current_strategy {
273 if let (Some(current_metrics), Some(optimal_metrics)) = (
275 monitor.strategy_performance.get(&self.current_strategy),
276 monitor.strategy_performance.get(&optimal_strategy),
277 ) {
278 let performance_ratio =
279 optimal_metrics.efficiency_score / current_metrics.efficiency_score;
280
281 if performance_ratio >= self.performance_threshold {
282 self.current_strategy = optimal_strategy;
283 self.last_switch_step = current_step;
284 return Some(optimal_strategy);
285 }
286 }
287 }
288
289 None
290 }
291}
292
/// Tracking record for one in-flight asynchronous communication operation.
#[derive(Debug)]
pub struct AsyncCommunicationHandle {
    /// Sequential identifier (index at creation time).
    id: usize,
    /// When the operation was launched.
    start_time: std::time::Instant,
    /// Duration after which the operation is considered finished/expired.
    expected_completion: std::time::Duration,
    /// Strategy that launched this operation.
    strategy: SyncStrategy,
    /// Payload size in bytes.
    data_size: usize,
    /// Current lifecycle state.
    status: AsyncCommStatus,
}
309
/// Lifecycle state of an asynchronous communication operation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AsyncCommStatus {
    /// Queued but not yet started.
    Pending,
    /// Currently executing.
    InProgress,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Exceeded its allowed completion time.
    Timeout,
}
318
/// Snapshot of communication performance, returned by
/// `MultiGpuSync::get_performance_stats`.
#[derive(Debug, Clone)]
pub struct CommunicationPerformanceStats {
    /// Overall average bandwidth in GB/s.
    pub average_bandwidth_gb_s: f64,
    /// Total number of communication operations recorded.
    pub total_operations: usize,
    /// Total data transferred, in gigabytes.
    pub total_data_transferred_gb: f64,
    /// Strategy currently selected by the adaptive selector.
    pub current_strategy: SyncStrategy,
    /// Number of outstanding async operation handles.
    pub pending_async_ops: usize,
    /// Number of `sync_gradients` calls performed so far.
    pub step_count: usize,
}
329
impl<A: Float + GpuDataType + Send + Sync> MultiGpuSync<A> {
    /// Creates a synchronization manager for one rank.
    ///
    /// Loads the kernels implied by `config` (strategy + compression flags),
    /// allocates device workspace sized for `max_param_size` elements, and
    /// initializes the performance monitor and adaptive selector.
    ///
    /// # Errors
    /// Propagates kernel-loading or buffer-allocation failures.
    pub fn new(
        context: Arc<GpuContext>,
        config: MultiGpuConfig,
        max_param_size: usize,
    ) -> Result<Self, GpuOptimError> {
        let sync_kernels = Self::load_sync_kernels(&context, &config)?;

        let workspace = Self::allocate_workspace(&context, &config, max_param_size)?;

        let perf_monitor = CommunicationPerformanceMonitor::new();
        let adaptive_selector = AdaptiveCommunicationSelector::new();
        let async_handles = Vec::with_capacity(config.pipeline_depth);

        Ok(Self {
            context,
            config,
            sync_kernels,
            workspace,
            perf_monitor,
            adaptive_selector,
            async_handles,
            step_counter: 0,
            _phantom: PhantomData,
        })
    }

    /// Synchronizes `gradients` across all GPUs using the configured (or
    /// adaptively selected) strategy, recording timing/bandwidth metrics and
    /// periodically logging statistics.
    ///
    /// # Errors
    /// Returns `UnsupportedOperation` for `PipelineParallel` without async
    /// updates, plus any error from the underlying reduction.
    pub fn sync_gradients<S, D>(
        &mut self,
        gradients: &mut ArrayBase<S, D>,
    ) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        self.step_counter += 1;
        let tensor_size = gradients.len();
        let start_time = std::time::Instant::now();

        // Pick the strategy: adaptive mode may switch based on history,
        // otherwise the statically configured strategy is used.
        let strategy = if self.config.adaptive_communication {
            if let Some(new_strategy) = self.adaptive_selector.evaluate_and_switch(
                &self.perf_monitor,
                tensor_size,
                self.step_counter,
            ) {
                new_strategy
            } else {
                self.adaptive_selector.current_strategy
            }
        } else {
            self.config.sync_strategy
        };

        let result = match strategy {
            SyncStrategy::RingAllReduce => self.ring_allreduce(gradients),
            SyncStrategy::TreeAllReduce => self.tree_allreduce(gradients),
            SyncStrategy::HierarchicalAllReduce => self.hierarchical_allreduce(gradients),
            SyncStrategy::PipelineParallel => {
                if self.config.async_param_updates {
                    self.pipeline_parallel_async(gradients)
                } else {
                    Err(GpuOptimError::UnsupportedOperation(
                        "Pipeline parallel requires async updates enabled".to_string(),
                    ))
                }
            }
        };

        // Metrics are recorded whether or not the reduction succeeded.
        let elapsed = start_time.elapsed();
        let data_bytes = tensor_size * std::mem::size_of::<A>();

        self.perf_monitor.record_communication(
            strategy,
            data_bytes as u64,
            elapsed.as_micros() as u64,
        );

        // Periodic statistics logging at the configured interval.
        if self
            .step_counter
            .is_multiple_of(self.config.bandwidth_monitor_interval)
        {
            self.log_performance_statistics();
        }

        result
    }

    /// Ring all-reduce: dispatches the ring kernel once per chunk (one chunk
    /// per GPU), then copies the reduced data back to the host.
    ///
    /// No-op (returns `Ok`) when no GPU backend feature is enabled.
    fn ring_allreduce<S, D>(&self, gradients: &mut ArrayBase<S, D>) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .ring_allreduce
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            let grad_len = gradients.len();

            // The kernel operates on a flat buffer, so gradients must be
            // contiguous in memory.
            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            let grad_buffer = self.context.create_buffer_from_slice(grad_slice);

            // Round up so every element is covered by exactly one chunk.
            let chunk_size = grad_len.div_ceil(self.config.num_gpus);

            kernel.set_buffer("data", &grad_buffer);
            kernel.set_buffer(
                "recv_buffer",
                self.workspace.recv_buffer.as_ref().expect("unwrap failed"),
            );
            kernel.set_i32("chunk_size", chunk_size as i32);
            kernel.set_i32("rank", self.config.rank as i32);
            kernel.set_i32("world_size", self.config.num_gpus as i32);

            // One dispatch per chunk; each launch covers `chunk_size` elements.
            for chunk_id in 0..self.config.num_gpus {
                kernel.set_i32("chunk_id", chunk_id as i32);

                let (grid_size, block_size) = crate::utils::calculate_block_size(chunk_size, 256);
                kernel.dispatch([grid_size as u32, 1, 1]);
            }

            // Read the reduced gradients back into the caller's array.
            grad_buffer.copy_to_host(grad_slice)?;
        }

        Ok(())
    }

    /// Tree all-reduce: log2(world_size) rounds of XOR-peer (butterfly)
    /// exchanges, each dispatched over the full gradient length.
    ///
    /// No-op (returns `Ok`) when no GPU backend feature is enabled.
    fn tree_allreduce<S, D>(&self, gradients: &mut ArrayBase<S, D>) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .tree_allreduce
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            let grad_len = gradients.len();

            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            let grad_buffer = self.context.create_buffer_from_slice(grad_slice);

            // ceil(log2(world_size)) exchange rounds.
            let num_levels = (self.config.num_gpus as f32).log2().ceil() as usize;

            kernel.set_buffer("data", &grad_buffer);
            kernel.set_buffer(
                "workspace",
                self.workspace.workspace.as_ref().expect("unwrap failed"),
            );
            kernel.set_i32("rank", self.config.rank as i32);
            kernel.set_i32("world_size", self.config.num_gpus as i32);
            kernel.set_i32("data_size", grad_len as i32);

            for level in 0..num_levels {
                // Peer at this level is our rank with bit `level` flipped.
                let stride = 1 << level;
                let peer_rank = self.config.rank ^ stride;

                // Skip phantom peers when world size is not a power of two.
                if peer_rank < self.config.num_gpus {
                    kernel.set_i32("level", level as i32);
                    kernel.set_i32("peer_rank", peer_rank as i32);

                    let (grid_size, block_size) = crate::utils::calculate_block_size(grad_len, 256);
                    kernel.dispatch([grid_size as u32, 1, 1]);
                }
            }

            grad_buffer.copy_to_host(grad_slice)?;
        }

        Ok(())
    }

    /// Hierarchical all-reduce: dispatches the kernel in three phases
    /// (set via the "phase" parameter); phase 2 is run only by local group
    /// leaders (local_rank == 0).
    ///
    /// Phase semantics are presumably: 1 = intra-group reduce, 2 = inter-group
    /// reduce among leaders, 3 = broadcast back — defined by the kernel, not
    /// visible here; confirm against the kernel source.
    ///
    /// No-op (returns `Ok`) when no GPU backend feature is enabled.
    fn hierarchical_allreduce<S, D>(
        &self,
        gradients: &mut ArrayBase<S, D>,
    ) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .hierarchical_allreduce
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            let grad_len = gradients.len();

            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            // Decompose the global rank into (group, position-within-group).
            let local_rank = self.config.rank % self.config.local_group_size;
            let global_rank = self.config.rank / self.config.local_group_size;
            let global_size = self.config.num_gpus / self.config.local_group_size;

            let grad_buffer = self.context.create_buffer_from_slice(grad_slice);

            kernel.set_buffer("data", &grad_buffer);
            kernel.set_buffer(
                "workspace",
                self.workspace.workspace.as_ref().expect("unwrap failed"),
            );
            kernel.set_i32("local_rank", local_rank as i32);
            kernel.set_i32("local_size", self.config.local_group_size as i32);
            kernel.set_i32("global_rank", global_rank as i32);
            kernel.set_i32("global_size", global_size as i32);
            kernel.set_i32("data_size", grad_len as i32);

            // Phase 1: all ranks participate.
            kernel.set_i32("phase", 1);
            let (grid_size, block_size) = crate::utils::calculate_block_size(grad_len, 256);
            kernel.dispatch([grid_size as u32, 1, 1]);

            // Phase 2: only local group leaders participate.
            if local_rank == 0 {
                kernel.set_i32("phase", 2);
                kernel.dispatch([grid_size as u32, 1, 1]);
            }

            // Phase 3: all ranks participate again.
            kernel.set_i32("phase", 3);
            kernel.dispatch([grid_size as u32, 1, 1]);

            grad_buffer.copy_to_host(grad_slice)?;
        }

        Ok(())
    }

    /// Pipelined asynchronous synchronization: splits the gradient into
    /// `pipeline_depth` chunks and registers an async tracking handle.
    ///
    /// NOTE(review): each chunk buffer is created but never dispatched or
    /// awaited here — the actual transfer appears to be unimplemented
    /// placeholder logic; confirm before relying on it.
    fn pipeline_parallel_async<S, D>(
        &mut self,
        gradients: &mut ArrayBase<S, D>,
    ) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let grad_len = gradients.len();
            let data_size = grad_len * std::mem::size_of::<A>();
            let chunk_size = grad_len / self.config.pipeline_depth;

            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            // Bookkeeping handle for this whole pipelined operation.
            let handle = AsyncCommunicationHandle {
                id: self.async_handles.len(),
                start_time: std::time::Instant::now(),
                expected_completion: std::time::Duration::from_millis(10),
                strategy: SyncStrategy::PipelineParallel,
                data_size,
                status: AsyncCommStatus::InProgress,
            };

            // Stage the gradient in `pipeline_depth` contiguous chunks.
            for stage in 0..self.config.pipeline_depth {
                let start_idx = stage * chunk_size;
                let end_idx = ((stage + 1) * chunk_size).min(grad_len);

                // Guard against empty chunks (e.g. grad_len < pipeline_depth).
                if start_idx < end_idx {
                    let chunk_buffer = self
                        .context
                        .create_buffer_from_slice(&grad_slice[start_idx..end_idx]);
                }
            }

            self.async_handles.push(handle);

            // Prune stale handles once the backlog exceeds twice the depth.
            if self.async_handles.len() > self.config.pipeline_depth * 2 {
                self.cleanup_completed_handles();
            }
        }

        Ok(())
    }

    /// Prints average bandwidth, operation count, and the current strategy
    /// to stdout.
    fn log_performance_statistics(&self) {
        let avg_bandwidth = self.perf_monitor.get_average_bandwidth();
        let total_ops = self.perf_monitor.comm_operations;

        println!(
            "Multi-GPU Performance [Step {}]: {:.2} GB/s avg bandwidth, {} ops, current strategy: {:?}",
            self.step_counter,
            avg_bandwidth,
            total_ops,
            self.adaptive_selector.current_strategy
        );
    }

    /// Drops async handles whose expected completion window has elapsed.
    /// Note: this is purely time-based; `status` is not consulted.
    fn cleanup_completed_handles(&mut self) {
        let current_time = std::time::Instant::now();

        self.async_handles.retain(|handle| {
            let elapsed = current_time.duration_since(handle.start_time);

            // Keep only handles still inside their expected-completion window.
            if elapsed > handle.expected_completion {
                false
            } else {
                true
            }
        });
    }

    /// Returns a snapshot of the accumulated communication statistics.
    pub fn get_performance_stats(&self) -> CommunicationPerformanceStats {
        CommunicationPerformanceStats {
            average_bandwidth_gb_s: self.perf_monitor.get_average_bandwidth(),
            total_operations: self.perf_monitor.comm_operations,
            total_data_transferred_gb: self.perf_monitor.total_data_bytes as f64 / 1e9,
            current_strategy: self.adaptive_selector.current_strategy,
            pending_async_ops: self.async_handles.len(),
            step_count: self.step_counter,
        }
    }

    /// Marks all in-progress async operations as completed and prunes
    /// expired handles. No-op when no GPU backend feature is enabled.
    pub fn synchronize_all(&mut self) -> Result<(), GpuOptimError> {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            for handle in &mut self.async_handles {
                if handle.status == AsyncCommStatus::InProgress {
                    handle.status = AsyncCommStatus::Completed;
                }
            }

            self.cleanup_completed_handles();
        }

        Ok(())
    }

    /// Top-k gradient compression.
    ///
    /// NOTE(review): the compression kernel handle is looked up but never
    /// dispatched — the returned values/indices are zero-filled placeholders
    /// of length `len * compression_ratio`; confirm intended behavior.
    ///
    /// # Errors
    /// `NotInitialized` if the compression kernel was not loaded;
    /// `UnsupportedOperation` when no GPU backend feature is enabled.
    pub fn compress_gradients<S, D>(
        &mut self,
        gradients: &ArrayBase<S, D>,
    ) -> Result<(Vec<A>, Vec<i32>), GpuOptimError>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .compress_gradients
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            // Number of elements retained after compression.
            let k = (gradients.len() as f32 * self.config.compression_ratio) as usize;

            let compressed_values = vec![A::zero(); k];
            let compressed_indices = vec![0i32; k];

            Ok((compressed_values, compressed_indices))
        }

        #[cfg(not(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        )))]
        {
            Err(GpuOptimError::UnsupportedOperation(
                "GPU feature not enabled".to_string(),
            ))
        }
    }

    /// Loads only the kernels required by `config`: the kernel matching
    /// `sync_strategy`, plus compression kernels when compression is enabled.
    /// All handles are `None` when no GPU backend feature is enabled.
    fn load_sync_kernels(
        context: &Arc<GpuContext>,
        config: &MultiGpuConfig,
    ) -> Result<SyncKernels, GpuOptimError> {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let ring_kernel = if matches!(config.sync_strategy, SyncStrategy::RingAllReduce) {
                Some(Arc::new(context.get_kernel("ring_allreduce_f32")?))
            } else {
                None
            };

            let tree_kernel = if matches!(config.sync_strategy, SyncStrategy::TreeAllReduce) {
                Some(Arc::new(context.get_kernel("tree_allreduce_f32")?))
            } else {
                None
            };

            let hierarchical_kernel =
                if matches!(config.sync_strategy, SyncStrategy::HierarchicalAllReduce) {
                    Some(Arc::new(context.get_kernel("hierarchical_allreduce_f32")?))
                } else {
                    None
                };

            let compress_kernel = if config.gradient_compression {
                Some(Arc::new(context.get_kernel("compress_gradients_topk_f32")?))
            } else {
                None
            };

            let decompress_kernel = if config.gradient_compression {
                Some(Arc::new(context.get_kernel("decompress_gradients_f32")?))
            } else {
                None
            };

            Ok(SyncKernels {
                ring_allreduce: ring_kernel,
                tree_allreduce: tree_kernel,
                hierarchical_allreduce: hierarchical_kernel,
                compress_gradients: compress_kernel,
                decompress_gradients: decompress_kernel,
            })
        }

        #[cfg(not(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        )))]
        {
            Ok(SyncKernels {
                ring_allreduce: None,
                tree_allreduce: None,
                hierarchical_allreduce: None,
                compress_gradients: None,
                decompress_gradients: None,
            })
        }
    }

    /// Allocates device workspace sized for `max_param_size` elements;
    /// compression buffers are allocated only when compression is enabled.
    /// All buffers are `None` when no GPU backend feature is enabled.
    fn allocate_workspace(
        context: &Arc<GpuContext>,
        config: &MultiGpuConfig,
        max_param_size: usize,
    ) -> Result<WorkspaceBuffers<A>, GpuOptimError> {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let recv_buffer = Some(context.create_buffer::<A>(max_param_size));
            let workspace = Some(context.create_buffer::<A>(max_param_size));

            let (compressed_values, compressed_indices, error_feedback) =
                if config.gradient_compression {
                    // k = number of retained elements after top-k compression.
                    let k = (max_param_size as f32 * config.compression_ratio) as usize;
                    (
                        Some(context.create_buffer::<A>(k)),
                        Some(context.create_buffer::<i32>(k)),
                        Some(context.create_buffer::<A>(max_param_size)),
                    )
                } else {
                    (None, None, None)
                };

            Ok(WorkspaceBuffers {
                recv_buffer,
                workspace,
                compressed_values,
                compressed_indices,
                error_feedback,
            })
        }

        #[cfg(not(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        )))]
        {
            Ok(WorkspaceBuffers {
                recv_buffer: None,
                workspace: None,
                compressed_values: None,
                compressed_indices: None,
                error_feedback: None,
            })
        }
    }
}
920
/// Convenience bundle holding one GPU context and one sync manager per rank.
pub struct MultiGpuSetup {
    /// One GPU context per rank.
    pub contexts: Vec<Arc<GpuContext>>,
    /// One synchronization manager per rank (f32 gradients).
    pub sync_managers: Vec<MultiGpuSync<f32>>,
}
928
929impl MultiGpuSetup {
930 pub fn new(num_gpus: usize, max_param_size: usize) -> Result<Self, GpuOptimError> {
932 let mut contexts = Vec::new();
933 let mut sync_managers = Vec::new();
934
935 for rank in 0..num_gpus {
936 let context = Arc::new(GpuContext::new(scirs2_core::gpu::GpuBackend::Cuda)?);
938
939 let config = MultiGpuConfig {
941 num_gpus,
942 rank,
943 ..Default::default()
944 };
945
946 let sync_manager = MultiGpuSync::new(context.clone(), config, max_param_size)?;
947
948 contexts.push(context);
949 sync_managers.push(sync_manager);
950 }
951
952 Ok(Self {
953 contexts,
954 sync_managers,
955 })
956 }
957}
958
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must describe a single-GPU ring-all-reduce setup.
    #[test]
    fn test_multi_gpu_config_default() {
        let cfg = MultiGpuConfig::default();
        assert_eq!(cfg.num_gpus, 1);
        assert_eq!(cfg.rank, 0);
        assert_eq!(cfg.sync_strategy, SyncStrategy::RingAllReduce);
        assert!(!cfg.gradient_compression);
    }

    /// Every strategy variant round-trips through the config unchanged.
    #[test]
    fn test_sync_strategy_selection() {
        for strategy in [
            SyncStrategy::RingAllReduce,
            SyncStrategy::TreeAllReduce,
            SyncStrategy::HierarchicalAllReduce,
            SyncStrategy::PipelineParallel,
        ] {
            let cfg = MultiGpuConfig {
                sync_strategy: strategy,
                ..Default::default()
            };
            assert_eq!(cfg.sync_strategy, strategy);
        }
    }

    /// Recording two operations yields positive average bandwidth and an
    /// optimal strategy drawn from the recorded ones.
    #[test]
    fn test_communication_performance_monitor() {
        let mut monitor = CommunicationPerformanceMonitor::new();

        monitor.record_communication(SyncStrategy::RingAllReduce, 1_000_000, 1000);
        monitor.record_communication(SyncStrategy::TreeAllReduce, 2_000_000, 1000);

        assert_eq!(monitor.comm_operations, 2);
        assert!(monitor.get_average_bandwidth() > 0.0);

        let optimal = monitor.get_optimal_strategy(1_000_000);
        assert!(matches!(
            optimal,
            SyncStrategy::RingAllReduce | SyncStrategy::TreeAllReduce
        ));
    }

    /// The selector starts on ring all-reduce; after feeding the monitor
    /// results for a rival strategy, any switch must move away from ring.
    #[test]
    fn test_adaptive_communication_selector() {
        let mut selector = AdaptiveCommunicationSelector::new();
        let mut monitor = CommunicationPerformanceMonitor::new();

        assert_eq!(selector.current_strategy, SyncStrategy::RingAllReduce);

        for _ in 0..10 {
            monitor.record_communication(SyncStrategy::TreeAllReduce, 1_000_000, 500);
        }

        if let Some(switched_to) = selector.evaluate_and_switch(&monitor, 1_000_000, 100) {
            assert_ne!(switched_to, SyncStrategy::RingAllReduce);
        }
    }

    /// Struct-update syntax with explicit overrides preserves each field.
    #[test]
    fn test_multi_gpu_config_extended() {
        let cfg = MultiGpuConfig {
            num_gpus: 8,
            adaptive_communication: true,
            bandwidth_monitor_interval: 50,
            async_param_updates: true,
            communication_timeout_ms: 1000,
            error_correction: true,
            pipeline_depth: 4,
            ..Default::default()
        };

        assert_eq!(cfg.num_gpus, 8);
        assert!(cfg.adaptive_communication);
        assert_eq!(cfg.bandwidth_monitor_interval, 50);
        assert!(cfg.async_param_updates);
        assert_eq!(cfg.communication_timeout_ms, 1000);
        assert!(cfg.error_correction);
        assert_eq!(cfg.pipeline_depth, 4);
    }

    /// A freshly built handle reports exactly the fields it was given.
    #[test]
    fn test_async_communication_handle() {
        let pending = AsyncCommunicationHandle {
            id: 0,
            start_time: std::time::Instant::now(),
            expected_completion: std::time::Duration::from_millis(10),
            strategy: SyncStrategy::PipelineParallel,
            data_size: 1_000_000,
            status: AsyncCommStatus::Pending,
        };

        assert_eq!(pending.id, 0);
        assert_eq!(pending.strategy, SyncStrategy::PipelineParallel);
        assert_eq!(pending.data_size, 1_000_000);
        assert_eq!(pending.status, AsyncCommStatus::Pending);
    }

    /// Updating metrics produces a positive efficiency score, and large
    /// tensors score positively too.
    #[test]
    fn test_strategy_performance_metrics() {
        let mut metrics = StrategyPerformanceMetrics::new();

        metrics.update(10.0, 1000);
        metrics.update(15.0, 800);

        assert!(metrics.efficiency_score > 0.0);
        assert!(metrics.calculate_score(1_000_000) > 0.0);
    }

    /// The stats struct is a plain data carrier: fields read back verbatim.
    #[test]
    fn test_communication_performance_stats() {
        let snapshot = CommunicationPerformanceStats {
            average_bandwidth_gb_s: 10.5,
            total_operations: 100,
            total_data_transferred_gb: 50.0,
            current_strategy: SyncStrategy::RingAllReduce,
            pending_async_ops: 2,
            step_count: 1000,
        };

        assert_eq!(snapshot.average_bandwidth_gb_s, 10.5);
        assert_eq!(snapshot.total_operations, 100);
        assert_eq!(snapshot.total_data_transferred_gb, 50.0);
        assert_eq!(snapshot.current_strategy, SyncStrategy::RingAllReduce);
        assert_eq!(snapshot.pending_async_ops, 2);
        assert_eq!(snapshot.step_count, 1000);
    }
}