1use scirs2_core::gpu::{GpuBuffer, GpuContext, GpuDataType, GpuKernelHandle};
4use scirs2_core::ndarray::{ArrayBase, Data, DataMut, Dimension};
5use scirs2_core::numeric::Float;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use crate::backends::GpuBackend;
10use crate::GpuOptimError;
11
/// Strategy used to synchronize (all-reduce) gradients across GPUs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyncStrategy {
    /// Ring all-reduce: the tensor is split into one chunk per GPU and
    /// exchanged around a logical ring.
    RingAllReduce,
    /// Tree all-reduce: `ceil(log2(world_size))` levels of pairwise
    /// (XOR-partner) exchange.
    TreeAllReduce,
    /// Two-level all-reduce: reduce within a local group, then across
    /// group leaders, then broadcast back.
    HierarchicalAllReduce,
    /// Chunked, asynchronous pipelined transfers.
    PipelineParallel,
}
24
/// Configuration for multi-GPU gradient synchronization.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Total number of participating GPUs (world size).
    pub num_gpus: usize,
    /// This process's rank within `0..num_gpus`.
    pub rank: usize,
    /// Strategy used when adaptive selection is disabled.
    pub sync_strategy: SyncStrategy,
    /// Enables top-k gradient compression before communication.
    pub gradient_compression: bool,
    /// Fraction of gradient elements kept when compression is enabled.
    pub compression_ratio: f32,
    /// GPUs per local group for hierarchical all-reduce.
    pub local_group_size: usize,
    /// When true, the sync strategy is chosen at runtime from observed metrics.
    pub adaptive_communication: bool,
    /// Log performance statistics every this many sync steps.
    pub bandwidth_monitor_interval: usize,
    /// Allows asynchronous parameter updates (required for pipeline parallel).
    pub async_param_updates: bool,
    /// Communication timeout in milliseconds.
    /// NOTE(review): not enforced anywhere in the visible code — confirm.
    pub communication_timeout_ms: u64,
    /// Enables the error-feedback buffer for compressed communication.
    /// NOTE(review): only the buffer allocation is visible here — confirm use.
    pub error_correction: bool,
    /// Number of pipeline stages for `SyncStrategy::PipelineParallel`.
    pub pipeline_depth: usize,
}
53
54impl Default for MultiGpuConfig {
55 fn default() -> Self {
56 Self {
57 num_gpus: 1,
58 rank: 0,
59 sync_strategy: SyncStrategy::RingAllReduce,
60 gradient_compression: false,
61 compression_ratio: 0.1, local_group_size: 4,
63 adaptive_communication: true,
64 bandwidth_monitor_interval: 100,
65 async_param_updates: false,
66 communication_timeout_ms: 5000,
67 error_correction: true,
68 pipeline_depth: 2,
69 }
70 }
71}
72
/// Multi-GPU gradient synchronizer.
///
/// Owns the GPU kernels, reusable workspace buffers, and runtime
/// bookkeeping (performance monitoring, adaptive strategy selection,
/// async operation handles) needed to all-reduce gradients across
/// `config.num_gpus` devices.
pub struct MultiGpuSync<A: Float + GpuDataType> {
    /// Shared GPU context used for buffer creation and kernel lookup.
    context: Arc<GpuContext>,
    /// Static multi-GPU configuration.
    config: MultiGpuConfig,
    /// Kernels loaded for the configured strategy (others stay `None`).
    sync_kernels: SyncKernels,
    /// Pre-allocated device buffers reused across sync calls.
    workspace: WorkspaceBuffers<A>,
    /// Rolling communication statistics.
    perf_monitor: CommunicationPerformanceMonitor,
    /// Picks the best-performing strategy at runtime.
    adaptive_selector: AdaptiveCommunicationSelector,
    /// Outstanding asynchronous communication operations.
    async_handles: Vec<AsyncCommunicationHandle>,
    /// Number of `sync_gradients` calls performed so far.
    step_counter: usize,
    /// Ties the gradient element type `A` to this instance.
    _phantom: PhantomData<A>,
}
94
/// Kernel handles for each supported sync strategy. Only the kernels
/// required by the active configuration are loaded; the rest are `None`.
struct SyncKernels {
    /// Ring all-reduce kernel.
    ring_allreduce: Option<Arc<GpuKernelHandle>>,
    /// Tree all-reduce kernel.
    tree_allreduce: Option<Arc<GpuKernelHandle>>,
    /// Hierarchical (two-level) all-reduce kernel.
    hierarchical_allreduce: Option<Arc<GpuKernelHandle>>,
    /// Top-k gradient compression kernel.
    compress_gradients: Option<Arc<GpuKernelHandle>>,
    /// Gradient decompression kernel.
    decompress_gradients: Option<Arc<GpuKernelHandle>>,
}
103
/// Pre-allocated device buffers reused across synchronization calls;
/// every field is `None` when built without a GPU feature.
struct WorkspaceBuffers<A: Float + GpuDataType> {
    /// Receive buffer used by ring all-reduce.
    recv_buffer: Option<GpuBuffer<A>>,
    /// Scratch buffer used by tree and hierarchical all-reduce.
    workspace: Option<GpuBuffer<A>>,
    /// Top-k compressed gradient values (compression only).
    compressed_values: Option<GpuBuffer<A>>,
    /// Indices of the top-k values (compression only).
    compressed_indices: Option<GpuBuffer<i32>>,
    /// Error-feedback accumulator for compressed communication (compression only).
    error_feedback: Option<GpuBuffer<A>>,
}
112
/// Rolling statistics about communication operations, used for logging
/// and to drive adaptive strategy selection.
#[derive(Debug, Clone)]
pub struct CommunicationPerformanceMonitor {
    /// Total time spent communicating, in microseconds.
    total_comm_time_us: u64,
    /// Total bytes transferred.
    total_data_bytes: u64,
    /// Number of recorded communication operations.
    comm_operations: usize,
    /// Most recent bandwidth samples (GB/s), capped at 1000 entries.
    bandwidth_history: std::collections::VecDeque<f64>,
    /// Per-strategy rolling performance metrics.
    strategy_performance: std::collections::HashMap<SyncStrategy, StrategyPerformanceMetrics>,
    /// Last strategy judged optimal.
    /// NOTE(review): written only at construction in the visible code.
    optimal_strategy: SyncStrategy,
}
129
130impl CommunicationPerformanceMonitor {
131 fn new() -> Self {
132 Self {
133 total_comm_time_us: 0,
134 total_data_bytes: 0,
135 comm_operations: 0,
136 bandwidth_history: std::collections::VecDeque::with_capacity(1000),
137 strategy_performance: std::collections::HashMap::new(),
138 optimal_strategy: SyncStrategy::RingAllReduce,
139 }
140 }
141
142 fn record_communication(&mut self, strategy: SyncStrategy, data_bytes: u64, timeus: u64) {
143 self.total_comm_time_us += timeus;
144 self.total_data_bytes += data_bytes;
145 self.comm_operations += 1;
146
147 let bandwidth_gb_s = (data_bytes as f64) / (timeus as f64 / 1_000_000.0) / 1e9;
148 self.bandwidth_history.push_back(bandwidth_gb_s);
149
150 if self.bandwidth_history.len() > 1000 {
151 self.bandwidth_history.pop_front();
152 }
153
154 let metrics = self
156 .strategy_performance
157 .entry(strategy)
158 .or_insert_with(StrategyPerformanceMetrics::new);
159 metrics.update(bandwidth_gb_s, timeus);
160 }
161
162 fn get_average_bandwidth(&self) -> f64 {
163 if self.total_comm_time_us == 0 {
164 0.0
165 } else {
166 (self.total_data_bytes as f64) / (self.total_comm_time_us as f64 / 1_000_000.0) / 1e9
167 }
168 }
169
170 fn get_optimal_strategy(&self, tensorsize: usize) -> SyncStrategy {
171 let mut best_strategy = SyncStrategy::RingAllReduce;
172 let mut best_score = 0.0;
173
174 for (strategy, metrics) in &self.strategy_performance {
175 let score = metrics.calculate_score(tensorsize);
176 if score > best_score {
177 best_score = score;
178 best_strategy = *strategy;
179 }
180 }
181
182 best_strategy
183 }
184}
185
/// Rolling performance window for a single sync strategy.
#[derive(Debug, Clone)]
struct StrategyPerformanceMetrics {
    /// Recent bandwidth samples (GB/s), capped at 100 entries.
    bandwidth_samples: std::collections::VecDeque<f64>,
    /// Recent latency samples (µs), capped at 100 entries.
    latency_samples: std::collections::VecDeque<u64>,
    /// Recent tensor sizes.
    /// NOTE(review): allocated but never written in the visible code.
    tensor_sizes: std::collections::VecDeque<usize>,
    /// Mean bandwidth divided by mean latency in ms; higher is better.
    efficiency_score: f64,
}
194
195impl StrategyPerformanceMetrics {
196 fn new() -> Self {
197 Self {
198 bandwidth_samples: std::collections::VecDeque::with_capacity(100),
199 latency_samples: std::collections::VecDeque::with_capacity(100),
200 tensor_sizes: std::collections::VecDeque::with_capacity(100),
201 efficiency_score: 0.0,
202 }
203 }
204
205 fn update(&mut self, bandwidth_gb_s: f64, latencyus: u64) {
206 self.bandwidth_samples.push_back(bandwidth_gb_s);
207 self.latency_samples.push_back(latencyus);
208
209 if self.bandwidth_samples.len() > 100 {
210 self.bandwidth_samples.pop_front();
211 self.latency_samples.pop_front();
212 }
213
214 let avg_bandwidth =
216 self.bandwidth_samples.iter().sum::<f64>() / self.bandwidth_samples.len() as f64;
217 let avg_latency =
218 self.latency_samples.iter().sum::<u64>() as f64 / self.latency_samples.len() as f64;
219
220 self.efficiency_score = avg_bandwidth / (avg_latency / 1000.0); }
222
223 fn calculate_score(&self, tensorsize: usize) -> f64 {
224 let size_factor = if tensorsize > 1000000 { 2.0 } else { 1.0 }; self.efficiency_score * size_factor
227 }
228}
229
/// Chooses the sync strategy at runtime based on observed performance.
#[derive(Debug)]
pub struct AdaptiveCommunicationSelector {
    /// Strategy currently in use.
    current_strategy: SyncStrategy,
    /// Minimum number of steps between strategy switches.
    switch_cooldown: usize,
    /// Step at which the last switch happened.
    last_switch_step: usize,
    /// Window length for switch evaluation.
    /// NOTE(review): never read in the visible code.
    evaluation_window: usize,
    /// Minimum efficiency ratio (candidate / current) required to switch.
    performance_threshold: f64,
}
244
245impl AdaptiveCommunicationSelector {
246 fn new() -> Self {
247 Self {
248 current_strategy: SyncStrategy::RingAllReduce,
249 switch_cooldown: 50,
250 last_switch_step: 0,
251 evaluation_window: 20,
252 performance_threshold: 1.2, }
254 }
255
256 fn should_evaluate_strategy(&self, currentstep: usize) -> bool {
257 currentstep - self.last_switch_step >= self.switch_cooldown
258 }
259
260 fn evaluate_and_switch(
261 &mut self,
262 monitor: &CommunicationPerformanceMonitor,
263 tensor_size: usize,
264 current_step: usize,
265 ) -> Option<SyncStrategy> {
266 if !self.should_evaluate_strategy(current_step) {
267 return None;
268 }
269
270 let optimal_strategy = monitor.get_optimal_strategy(tensor_size);
271
272 if optimal_strategy != self.current_strategy {
273 if let (Some(current_metrics), Some(optimal_metrics)) = (
275 monitor.strategy_performance.get(&self.current_strategy),
276 monitor.strategy_performance.get(&optimal_strategy),
277 ) {
278 let performance_ratio =
279 optimal_metrics.efficiency_score / current_metrics.efficiency_score;
280
281 if performance_ratio >= self.performance_threshold {
282 self.current_strategy = optimal_strategy;
283 self.last_switch_step = current_step;
284 return Some(optimal_strategy);
285 }
286 }
287 }
288
289 None
290 }
291}
292
/// Bookkeeping for one in-flight asynchronous communication operation.
#[derive(Debug)]
pub struct AsyncCommunicationHandle {
    /// Sequential identifier within the owning synchronizer.
    id: usize,
    /// When the operation started.
    start_time: std::time::Instant,
    /// Estimated duration; handles past this window are pruned.
    expected_completion: std::time::Duration,
    /// Strategy used by this operation.
    strategy: SyncStrategy,
    /// Payload size in bytes.
    data_size: usize,
    /// Current lifecycle state.
    status: AsyncCommStatus,
}
309
/// Lifecycle state of an asynchronous communication operation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AsyncCommStatus {
    /// Queued but not yet started.
    Pending,
    /// Currently executing.
    InProgress,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Exceeded its time budget.
    Timeout,
}
318
/// Snapshot of communication statistics for external consumers.
#[derive(Debug, Clone)]
pub struct CommunicationPerformanceStats {
    /// Average bandwidth over all recorded operations, in GB/s.
    pub average_bandwidth_gb_s: f64,
    /// Number of recorded communication operations.
    pub total_operations: usize,
    /// Total data transferred, in gigabytes.
    pub total_data_transferred_gb: f64,
    /// Strategy currently selected.
    pub current_strategy: SyncStrategy,
    /// Number of outstanding async operations.
    pub pending_async_ops: usize,
    /// Number of sync steps performed.
    pub step_count: usize,
}
329
330impl<A: Float + GpuDataType + Send + Sync> MultiGpuSync<A> {
331 pub fn new(
333 context: Arc<GpuContext>,
334 config: MultiGpuConfig,
335 max_param_size: usize,
336 ) -> Result<Self, GpuOptimError> {
337 let sync_kernels = Self::load_sync_kernels(&context, &config)?;
339
340 let workspace = Self::allocate_workspace(&context, &config, max_param_size)?;
342
343 let perf_monitor = CommunicationPerformanceMonitor::new();
345 let adaptive_selector = AdaptiveCommunicationSelector::new();
346 let async_handles = Vec::with_capacity(config.pipeline_depth);
347
348 Ok(Self {
349 context,
350 config,
351 sync_kernels,
352 workspace,
353 perf_monitor,
354 adaptive_selector,
355 async_handles,
356 step_counter: 0,
357 _phantom: PhantomData,
358 })
359 }
360
    /// Synchronizes (all-reduces) `gradients` across all configured GPUs.
    ///
    /// Picks a strategy (adaptively when enabled, otherwise the configured
    /// one), runs it, records timing and bandwidth, and periodically logs
    /// performance statistics.
    ///
    /// # Errors
    /// Propagates kernel/buffer errors from the chosen strategy; pipeline
    /// parallel additionally requires `async_param_updates` to be enabled.
    pub fn sync_gradients<S, D>(
        &mut self,
        gradients: &mut ArrayBase<S, D>,
    ) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        self.step_counter += 1;
        let tensor_size = gradients.len();
        let start_time = std::time::Instant::now();

        // Strategy choice: the adaptive selector may switch to a better
        // performer; otherwise use its current choice or the static config.
        let strategy = if self.config.adaptive_communication {
            if let Some(new_strategy) = self.adaptive_selector.evaluate_and_switch(
                &self.perf_monitor,
                tensor_size,
                self.step_counter,
            ) {
                new_strategy
            } else {
                self.adaptive_selector.current_strategy
            }
        } else {
            self.config.sync_strategy
        };

        let result = match strategy {
            SyncStrategy::RingAllReduce => self.ring_allreduce(gradients),
            SyncStrategy::TreeAllReduce => self.tree_allreduce(gradients),
            SyncStrategy::HierarchicalAllReduce => self.hierarchical_allreduce(gradients),
            SyncStrategy::PipelineParallel => {
                if self.config.async_param_updates {
                    self.pipeline_parallel_async(gradients)
                } else {
                    Err(GpuOptimError::UnsupportedOperation(
                        "Pipeline parallel requires async updates enabled".to_string(),
                    ))
                }
            }
        };

        // Timing is recorded even when the strategy errored, so the monitor
        // sees every attempt.
        let elapsed = start_time.elapsed();
        let data_bytes = tensor_size * std::mem::size_of::<A>();

        self.perf_monitor.record_communication(
            strategy,
            data_bytes as u64,
            elapsed.as_micros() as u64,
        );

        // Periodic performance logging.
        if self
            .step_counter
            .is_multiple_of(self.config.bandwidth_monitor_interval)
        {
            self.log_performance_statistics();
        }

        result
    }
425
    /// Performs a ring all-reduce on `gradients` with the preloaded kernel.
    ///
    /// The gradient is staged into a device buffer, the ring kernel is
    /// dispatched once per chunk (one chunk per GPU), and the result is
    /// copied back. No-op returning `Ok(())` when no GPU feature is enabled.
    ///
    /// # Errors
    /// * `GpuOptimError::NotInitialized` if the ring kernel was not loaded.
    /// * `GpuOptimError::InvalidState` if `gradients` is not contiguous.
    fn ring_allreduce<S, D>(&self, gradients: &mut ArrayBase<S, D>) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .ring_allreduce
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            let grad_len = gradients.len();

            // The kernel operates on a flat, contiguous buffer.
            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            let grad_buffer = self.context.create_buffer_from_slice(grad_slice);

            // One chunk per GPU; the last chunk may be shorter.
            let chunk_size = grad_len.div_ceil(self.config.num_gpus);

            kernel.set_buffer("data", &grad_buffer);
            // recv_buffer is always allocated when a GPU feature is enabled.
            kernel.set_buffer("recv_buffer", self.workspace.recv_buffer.as_ref().unwrap());
            kernel.set_i32("chunk_size", chunk_size as i32);
            kernel.set_i32("rank", self.config.rank as i32);
            kernel.set_i32("world_size", self.config.num_gpus as i32);

            // One dispatch per chunk; the inter-GPU exchange is presumably
            // handled inside the kernel — TODO(review): confirm.
            for chunk_id in 0..self.config.num_gpus {
                kernel.set_i32("chunk_id", chunk_id as i32);

                let (grid_size, block_size) = crate::utils::calculate_block_size(chunk_size, 256);
                kernel.dispatch([grid_size as u32, 1, 1]);
            }

            // Bring the reduced gradient back to host memory.
            grad_buffer.copy_to_host(grad_slice)?;
        }

        Ok(())
    }
479
    /// Performs a tree all-reduce on `gradients` with the preloaded kernel.
    ///
    /// Runs `ceil(log2(num_gpus))` levels; at each level this rank pairs
    /// with `rank XOR (1 << level)`. No-op when no GPU feature is enabled.
    ///
    /// # Errors
    /// * `GpuOptimError::NotInitialized` if the tree kernel was not loaded.
    /// * `GpuOptimError::InvalidState` if `gradients` is not contiguous.
    fn tree_allreduce<S, D>(&self, gradients: &mut ArrayBase<S, D>) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .tree_allreduce
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            let grad_len = gradients.len();

            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            let grad_buffer = self.context.create_buffer_from_slice(grad_slice);

            // Number of pairwise-exchange levels in the tree.
            let num_levels = (self.config.num_gpus as f32).log2().ceil() as usize;

            kernel.set_buffer("data", &grad_buffer);
            kernel.set_buffer("workspace", self.workspace.workspace.as_ref().unwrap());
            kernel.set_i32("rank", self.config.rank as i32);
            kernel.set_i32("world_size", self.config.num_gpus as i32);
            kernel.set_i32("data_size", grad_len as i32);

            for level in 0..num_levels {
                // XOR pairing: the partner differs only in bit `level`.
                let stride = 1 << level;
                let peer_rank = self.config.rank ^ stride;

                // Skip partners beyond the world size (non-power-of-two setups).
                if peer_rank < self.config.num_gpus {
                    kernel.set_i32("level", level as i32);
                    kernel.set_i32("peer_rank", peer_rank as i32);

                    let (grid_size, block_size) = crate::utils::calculate_block_size(grad_len, 256);
                    kernel.dispatch([grid_size as u32, 1, 1]);
                }
            }

            grad_buffer.copy_to_host(grad_slice)?;
        }

        Ok(())
    }
541
    /// Performs a two-level (hierarchical) all-reduce on `gradients`.
    ///
    /// Ranks are split into groups of `local_group_size`; the kernel is
    /// dispatched in three numbered phases — presumably local reduce,
    /// inter-group reduce (group leaders only), then broadcast back.
    /// TODO(review): confirm phase semantics against the kernel source.
    /// No-op when no GPU feature is enabled.
    ///
    /// # Errors
    /// * `GpuOptimError::NotInitialized` if the kernel was not loaded.
    /// * `GpuOptimError::InvalidState` if `gradients` is not contiguous.
    fn hierarchical_allreduce<S, D>(
        &self,
        gradients: &mut ArrayBase<S, D>,
    ) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .hierarchical_allreduce
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            let grad_len = gradients.len();

            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            // Position within the local group and across groups.
            let local_rank = self.config.rank % self.config.local_group_size;
            let global_rank = self.config.rank / self.config.local_group_size;
            let global_size = self.config.num_gpus / self.config.local_group_size;

            let grad_buffer = self.context.create_buffer_from_slice(grad_slice);

            kernel.set_buffer("data", &grad_buffer);
            kernel.set_buffer("workspace", self.workspace.workspace.as_ref().unwrap());
            kernel.set_i32("local_rank", local_rank as i32);
            kernel.set_i32("local_size", self.config.local_group_size as i32);
            kernel.set_i32("global_rank", global_rank as i32);
            kernel.set_i32("global_size", global_size as i32);
            kernel.set_i32("data_size", grad_len as i32);
            // Phase 1: dispatched on every rank.
            kernel.set_i32("phase", 1);
            let (grid_size, block_size) = crate::utils::calculate_block_size(grad_len, 256);
            kernel.dispatch([grid_size as u32, 1, 1]);
            // Phase 2: only group leaders (local_rank == 0) participate.
            if local_rank == 0 {
                kernel.set_i32("phase", 2);
                kernel.dispatch([grid_size as u32, 1, 1]);
            }

            // Phase 3: dispatched on every rank.
            kernel.set_i32("phase", 3);
            kernel.dispatch([grid_size as u32, 1, 1]);
            grad_buffer.copy_to_host(grad_slice)?;
        }

        Ok(())
    }
611
    /// Starts an asynchronous, pipelined synchronization of `gradients`.
    ///
    /// The tensor is split into `pipeline_depth` chunks, each staged into
    /// its own device buffer, and a tracking handle is queued.
    ///
    /// NOTE(review): the per-chunk buffer is created but never dispatched
    /// or read back, and the handle uses a fixed 10 ms completion estimate —
    /// this path looks like a partial implementation; confirm.
    /// No-op when no GPU feature is enabled.
    fn pipeline_parallel_async<S, D>(
        &mut self,
        gradients: &mut ArrayBase<S, D>,
    ) -> Result<(), GpuOptimError>
    where
        S: DataMut<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let grad_len = gradients.len();
            let data_size = grad_len * std::mem::size_of::<A>();
            // NOTE(review): integer division — panics if pipeline_depth is 0
            // and yields empty stages when grad_len < pipeline_depth.
            let chunk_size = grad_len / self.config.pipeline_depth;

            let grad_slice = gradients.as_slice_mut().ok_or_else(|| {
                GpuOptimError::InvalidState("Gradients must be contiguous".to_string())
            })?;

            // Track this operation so synchronize_all/cleanup can see it.
            let handle = AsyncCommunicationHandle {
                id: self.async_handles.len(),
                start_time: std::time::Instant::now(),
                expected_completion: std::time::Duration::from_millis(10),
                strategy: SyncStrategy::PipelineParallel,
                data_size,
                status: AsyncCommStatus::InProgress,
            };

            // Stage each pipeline chunk into its own device buffer.
            for stage in 0..self.config.pipeline_depth {
                let start_idx = stage * chunk_size;
                let end_idx = ((stage + 1) * chunk_size).min(grad_len);

                if start_idx < end_idx {
                    let chunk_buffer = self
                        .context
                        .create_buffer_from_slice(&grad_slice[start_idx..end_idx]);
                }
            }

            self.async_handles.push(handle);

            // Bound the backlog of tracking handles.
            if self.async_handles.len() > self.config.pipeline_depth * 2 {
                self.cleanup_completed_handles();
            }
        }

        Ok(())
    }
674
675 fn log_performance_statistics(&self) {
677 let avg_bandwidth = self.perf_monitor.get_average_bandwidth();
678 let total_ops = self.perf_monitor.comm_operations;
679
680 println!(
681 "Multi-GPU Performance [Step {}]: {:.2} GB/s avg bandwidth, {} ops, current strategy: {:?}",
682 self.step_counter,
683 avg_bandwidth,
684 total_ops,
685 self.adaptive_selector.current_strategy
686 );
687 }
688
689 fn cleanup_completed_handles(&mut self) {
691 let current_time = std::time::Instant::now();
692
693 self.async_handles.retain(|handle| {
694 let elapsed = current_time.duration_since(handle.start_time);
695
696 if elapsed > handle.expected_completion {
697 false } else {
700 true }
702 });
703 }
704
705 pub fn get_performance_stats(&self) -> CommunicationPerformanceStats {
707 CommunicationPerformanceStats {
708 average_bandwidth_gb_s: self.perf_monitor.get_average_bandwidth(),
709 total_operations: self.perf_monitor.comm_operations,
710 total_data_transferred_gb: self.perf_monitor.total_data_bytes as f64 / 1e9,
711 current_strategy: self.adaptive_selector.current_strategy,
712 pending_async_ops: self.async_handles.len(),
713 step_count: self.step_counter,
714 }
715 }
716
717 pub fn synchronize_all(&mut self) -> Result<(), GpuOptimError> {
719 #[cfg(any(
720 feature = "cuda",
721 feature = "metal",
722 feature = "opencl",
723 feature = "wgpu"
724 ))]
725 {
726 for handle in &mut self.async_handles {
730 if handle.status == AsyncCommStatus::InProgress {
731 handle.status = AsyncCommStatus::Completed;
732 }
733 }
734
735 self.cleanup_completed_handles();
736 }
737
738 Ok(())
739 }
740
    /// Compresses `gradients` to the top-k values/indices implied by
    /// `compression_ratio`.
    ///
    /// NOTE(review): this currently only verifies the compression kernel is
    /// loaded and returns zero-filled placeholder vectors of length
    /// `k = len * compression_ratio` — the kernel is never dispatched.
    ///
    /// # Errors
    /// * `GpuOptimError::NotInitialized` if the compression kernel is absent.
    /// * `GpuOptimError::UnsupportedOperation` when built without a GPU feature.
    pub fn compress_gradients<S, D>(
        &mut self,
        gradients: &ArrayBase<S, D>,
    ) -> Result<(Vec<A>, Vec<i32>), GpuOptimError>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let kernel = self
                .sync_kernels
                .compress_gradients
                .as_ref()
                .ok_or(GpuOptimError::NotInitialized)?;

            // Number of elements retained after top-k compression.
            let k = (gradients.len() as f32 * self.config.compression_ratio) as usize;

            let compressed_values = vec![A::zero(); k];
            let compressed_indices = vec![0i32; k];

            Ok((compressed_values, compressed_indices))
        }

        #[cfg(not(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        )))]
        {
            Err(GpuOptimError::UnsupportedOperation(
                "GPU feature not enabled".to_string(),
            ))
        }
    }
787
    /// Loads the GPU kernels required by `config`.
    ///
    /// Only the kernel matching `config.sync_strategy` (plus the
    /// compression pair when `gradient_compression` is set) is loaded;
    /// the rest stay `None`.
    ///
    /// NOTE(review): with `adaptive_communication` enabled the selector may
    /// later pick a strategy whose kernel was never loaded, which will
    /// surface as `NotInitialized` at sync time — confirm this is intended.
    ///
    /// Without any GPU feature, returns a struct with every kernel `None`.
    fn load_sync_kernels(
        context: &Arc<GpuContext>,
        config: &MultiGpuConfig,
    ) -> Result<SyncKernels, GpuOptimError> {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let ring_kernel = if matches!(config.sync_strategy, SyncStrategy::RingAllReduce) {
                Some(Arc::new(context.get_kernel("ring_allreduce_f32")?))
            } else {
                None
            };

            let tree_kernel = if matches!(config.sync_strategy, SyncStrategy::TreeAllReduce) {
                Some(Arc::new(context.get_kernel("tree_allreduce_f32")?))
            } else {
                None
            };

            let hierarchical_kernel =
                if matches!(config.sync_strategy, SyncStrategy::HierarchicalAllReduce) {
                    Some(Arc::new(context.get_kernel("hierarchical_allreduce_f32")?))
                } else {
                    None
                };

            // Compression kernels are loaded as a pair when enabled.
            let compress_kernel = if config.gradient_compression {
                Some(Arc::new(context.get_kernel("compress_gradients_topk_f32")?))
            } else {
                None
            };

            let decompress_kernel = if config.gradient_compression {
                Some(Arc::new(context.get_kernel("decompress_gradients_f32")?))
            } else {
                None
            };

            Ok(SyncKernels {
                ring_allreduce: ring_kernel,
                tree_allreduce: tree_kernel,
                hierarchical_allreduce: hierarchical_kernel,
                compress_gradients: compress_kernel,
                decompress_gradients: decompress_kernel,
            })
        }

        #[cfg(not(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        )))]
        {
            Ok(SyncKernels {
                ring_allreduce: None,
                tree_allreduce: None,
                hierarchical_allreduce: None,
                compress_gradients: None,
                decompress_gradients: None,
            })
        }
    }
856
    /// Allocates the reusable device buffers, sized for `max_param_size`
    /// elements.
    ///
    /// Compression buffers (top-k values/indices and error feedback) are
    /// only allocated when `config.gradient_compression` is enabled.
    /// Without any GPU feature, every buffer is `None`.
    fn allocate_workspace(
        context: &Arc<GpuContext>,
        config: &MultiGpuConfig,
        max_param_size: usize,
    ) -> Result<WorkspaceBuffers<A>, GpuOptimError> {
        #[cfg(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        ))]
        {
            let recv_buffer = Some(context.create_buffer::<A>(max_param_size));
            let workspace = Some(context.create_buffer::<A>(max_param_size));

            let (compressed_values, compressed_indices, error_feedback) =
                if config.gradient_compression {
                    // k = number of elements kept by top-k compression.
                    let k = (max_param_size as f32 * config.compression_ratio) as usize;
                    (
                        Some(context.create_buffer::<A>(k)),
                        Some(context.create_buffer::<i32>(k)),
                        Some(context.create_buffer::<A>(max_param_size)),
                    )
                } else {
                    (None, None, None)
                };

            Ok(WorkspaceBuffers {
                recv_buffer,
                workspace,
                compressed_values,
                compressed_indices,
                error_feedback,
            })
        }

        #[cfg(not(any(
            feature = "cuda",
            feature = "metal",
            feature = "opencl",
            feature = "wgpu"
        )))]
        {
            Ok(WorkspaceBuffers {
                recv_buffer: None,
                workspace: None,
                compressed_values: None,
                compressed_indices: None,
                error_feedback: None,
            })
        }
    }
910}
911
/// Convenience bundle holding one context and one sync manager per GPU.
pub struct MultiGpuSetup {
    /// One GPU context per rank.
    pub contexts: Vec<Arc<GpuContext>>,
    /// One `f32` sync manager per rank.
    pub sync_managers: Vec<MultiGpuSync<f32>>,
}
919
impl MultiGpuSetup {
    /// Creates one context and one sync manager for each rank in
    /// `0..num_gpus`, all sharing default configuration except for
    /// `num_gpus` and `rank`.
    ///
    /// NOTE(review): every context is created with the CUDA backend
    /// (`GpuBackend::Cuda`) regardless of platform — confirm whether the
    /// backend should be configurable here.
    ///
    /// # Errors
    /// Propagates context-creation and sync-manager-construction failures.
    pub fn new(num_gpus: usize, max_param_size: usize) -> Result<Self, GpuOptimError> {
        let mut contexts = Vec::new();
        let mut sync_managers = Vec::new();

        for rank in 0..num_gpus {
            let context = Arc::new(GpuContext::new(scirs2_core::gpu::GpuBackend::Cuda)?);

            // Each rank gets its own config; other fields use defaults.
            let config = MultiGpuConfig {
                num_gpus,
                rank,
                ..Default::default()
            };

            let sync_manager = MultiGpuSync::new(context.clone(), config, max_param_size)?;

            contexts.push(context);
            sync_managers.push(sync_manager);
        }

        Ok(Self {
            contexts,
            sync_managers,
        })
    }
}
949
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_gpu_config_default() {
        // Defaults describe a single-GPU, uncompressed, ring setup.
        let cfg = MultiGpuConfig::default();
        assert_eq!(cfg.num_gpus, 1);
        assert_eq!(cfg.rank, 0);
        assert_eq!(cfg.sync_strategy, SyncStrategy::RingAllReduce);
        assert!(!cfg.gradient_compression);
    }

    #[test]
    fn test_sync_strategy_selection() {
        // Every strategy round-trips through the config unchanged.
        for &strategy in &[
            SyncStrategy::RingAllReduce,
            SyncStrategy::TreeAllReduce,
            SyncStrategy::HierarchicalAllReduce,
            SyncStrategy::PipelineParallel,
        ] {
            let cfg = MultiGpuConfig {
                sync_strategy: strategy,
                ..Default::default()
            };
            assert_eq!(cfg.sync_strategy, strategy);
        }
    }

    #[test]
    fn test_communication_performance_monitor() {
        let mut monitor = CommunicationPerformanceMonitor::new();

        // Two recorded ops with equal duration but different sizes.
        monitor.record_communication(SyncStrategy::RingAllReduce, 1000000, 1000);
        monitor.record_communication(SyncStrategy::TreeAllReduce, 2000000, 1000);

        assert_eq!(monitor.comm_operations, 2);
        assert!(monitor.get_average_bandwidth() > 0.0);

        // The optimal strategy must be one of the two we recorded.
        let optimal = monitor.get_optimal_strategy(1000000);
        assert!(matches!(
            optimal,
            SyncStrategy::RingAllReduce | SyncStrategy::TreeAllReduce
        ));
    }

    #[test]
    fn test_adaptive_communication_selector() {
        let mut selector = AdaptiveCommunicationSelector::new();
        let mut monitor = CommunicationPerformanceMonitor::new();

        assert_eq!(selector.current_strategy, SyncStrategy::RingAllReduce);

        // Feed the monitor data favoring tree all-reduce.
        for _ in 0..10 {
            monitor.record_communication(SyncStrategy::TreeAllReduce, 1000000, 500);
        }

        // A switch (if any) must move away from ring all-reduce.
        if let Some(strategy) = selector.evaluate_and_switch(&monitor, 1000000, 100) {
            assert_ne!(strategy, SyncStrategy::RingAllReduce);
        }
    }

    #[test]
    fn test_multi_gpu_config_extended() {
        let cfg = MultiGpuConfig {
            num_gpus: 8,
            adaptive_communication: true,
            bandwidth_monitor_interval: 50,
            async_param_updates: true,
            communication_timeout_ms: 1000,
            error_correction: true,
            pipeline_depth: 4,
            ..Default::default()
        };

        assert_eq!(cfg.num_gpus, 8);
        assert!(cfg.adaptive_communication);
        assert_eq!(cfg.bandwidth_monitor_interval, 50);
        assert!(cfg.async_param_updates);
        assert_eq!(cfg.communication_timeout_ms, 1000);
        assert!(cfg.error_correction);
        assert_eq!(cfg.pipeline_depth, 4);
    }

    #[test]
    fn test_async_communication_handle() {
        let handle = AsyncCommunicationHandle {
            id: 0,
            start_time: std::time::Instant::now(),
            expected_completion: std::time::Duration::from_millis(10),
            strategy: SyncStrategy::PipelineParallel,
            data_size: 1000000,
            status: AsyncCommStatus::Pending,
        };

        assert_eq!(handle.id, 0);
        assert_eq!(handle.strategy, SyncStrategy::PipelineParallel);
        assert_eq!(handle.data_size, 1000000);
        assert_eq!(handle.status, AsyncCommStatus::Pending);
    }

    #[test]
    fn test_strategy_performance_metrics() {
        let mut metrics = StrategyPerformanceMetrics::new();

        metrics.update(10.0, 1000);
        metrics.update(15.0, 800);
        assert!(metrics.efficiency_score > 0.0);

        // Large tensors still produce a positive score.
        assert!(metrics.calculate_score(1000000) > 0.0);
    }

    #[test]
    fn test_communication_performance_stats() {
        let stats = CommunicationPerformanceStats {
            average_bandwidth_gb_s: 10.5,
            total_operations: 100,
            total_data_transferred_gb: 50.0,
            current_strategy: SyncStrategy::RingAllReduce,
            pending_async_ops: 2,
            step_count: 1000,
        };

        assert_eq!(stats.average_bandwidth_gb_s, 10.5);
        assert_eq!(stats.total_operations, 100);
        assert_eq!(stats.total_data_transferred_gb, 50.0);
        assert_eq!(stats.current_strategy, SyncStrategy::RingAllReduce);
        assert_eq!(stats.pending_async_ops, 2);
        assert_eq!(stats.step_count, 1000);
    }
}