1use crate::context::{GpuContext, GpuContextConfig};
7use crate::error::{GpuError, GpuResult};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use tracing::{debug, info, warn};
11use wgpu::{Adapter, AdapterInfo, Backend, Backends, BufferUsages, Instance};
12
/// Configuration for multi-GPU discovery and scheduling.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// wgpu backends the manager may use when creating its `Instance`.
    pub backends: Backends,
    /// Minimum number of GPUs required; initialization fails below this.
    pub min_devices: usize,
    /// Maximum number of GPUs to initialize, even if more are found.
    pub max_devices: usize,
    /// When true, `select_device` picks the least-loaded GPU; otherwise
    /// everything runs on device 0.
    pub auto_load_balance: bool,
    /// Peer-to-peer transfer flag (not consulted anywhere in this module yet).
    pub enable_p2p: bool,
}

impl Default for MultiGpuConfig {
    /// All backends, at least 1 and at most 8 devices, balancing on, P2P off.
    fn default() -> Self {
        Self {
            backends: Backends::all(),
            min_devices: 1,
            max_devices: 8,
            auto_load_balance: true,
            enable_p2p: false,
        }
    }
}
39
/// Metadata describing one initialized GPU device.
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    /// Position of this device in the manager's device list.
    pub index: usize,
    /// Raw adapter information reported by wgpu.
    pub adapter_info: AdapterInfo,
    /// Backend (Vulkan, Metal, DX12, ...) this adapter runs on.
    pub backend: Backend,
    /// Estimated VRAM in bytes — a heuristic, not a queried value
    /// (see `MultiGpuManager::estimate_vram`).
    pub vram_bytes: Option<u64>,
    /// Whether the device is considered usable (always set true at creation).
    pub active: bool,
}

impl GpuDeviceInfo {
    /// Human-readable one-line summary, e.g. `GPU 0 : NVIDIA X (Vulkan)`.
    pub fn description(&self) -> String {
        format!(
            "GPU {} : {} ({:?})",
            self.index, self.adapter_info.name, self.backend
        )
    }
}
64
/// Owns one `GpuContext` per discovered GPU and schedules work across them.
pub struct MultiGpuManager {
    /// Initialized contexts, indexed by device number.
    devices: Vec<Arc<GpuContext>>,
    /// Metadata parallel to `devices` (same indices).
    device_info: Vec<GpuDeviceInfo>,
    /// Configuration used at initialization time.
    config: MultiGpuConfig,
    /// Shared load-balancing counters, also cloned into spawned tasks.
    load_state: Arc<Mutex<LoadBalanceState>>,
}
76
/// Mutable bookkeeping used to pick the least-loaded device.
#[derive(Debug, Clone)]
struct LoadBalanceState {
    /// Number of in-flight tasks per device index.
    task_counts: HashMap<usize, usize>,
    /// Accumulated workload estimate per device index.
    workload: HashMap<usize, f64>,
}

impl LoadBalanceState {
    /// Creates zeroed counters for device indices `0..num_devices`.
    fn new(num_devices: usize) -> Self {
        Self {
            task_counts: (0..num_devices).map(|i| (i, 0)).collect(),
            workload: (0..num_devices).map(|i| (i, 0.0)).collect(),
        }
    }

    /// Returns the device index with the smallest accumulated workload,
    /// or 0 when no devices are tracked.
    fn select_device(&self) -> usize {
        let least = self
            .workload
            .iter()
            .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal));
        least.map_or(0, |(idx, _)| *idx)
    }

    /// Registers one new task with the given workload estimate on `device`.
    fn add_task(&mut self, device: usize, workload: f64) {
        *self.task_counts.entry(device).or_default() += 1;
        *self.workload.entry(device).or_default() += workload;
    }

    /// Removes one finished task; both counters clamp at zero.
    fn complete_task(&mut self, device: usize, workload: f64) {
        if let Some(count) = self.task_counts.get_mut(&device) {
            *count = count.saturating_sub(1);
        }
        if let Some(load) = self.workload.get_mut(&device) {
            // Clamped subtraction: never drops below 0.0 (equivalent to the
            // original `load.max(workload) - workload`).
            *load = (*load - workload).max(0.0);
        }
    }
}
124
125impl MultiGpuManager {
    /// Discovers GPUs, initializes up to `config.max_devices` contexts, and
    /// returns the manager.
    ///
    /// # Errors
    /// * `GpuError::no_adapter` — fewer than `config.min_devices` adapters
    ///   were found during enumeration.
    /// * `GpuError::device_request` — fewer than `config.min_devices`
    ///   contexts initialized successfully.
    pub async fn new(config: MultiGpuConfig) -> GpuResult<Self> {
        info!("Initializing multi-GPU manager");

        let instance = Instance::new(&wgpu::InstanceDescriptor {
            backends: config.backends,
            ..Default::default()
        });

        // NOTE(review): `enumerate_adapters` ignores this instance and probes
        // a hard-coded backend list, so `config.backends` is not honored
        // during enumeration — confirm whether that is intended.
        let adapters = Self::enumerate_adapters(&instance).await;

        if adapters.len() < config.min_devices {
            return Err(GpuError::no_adapter(format!(
                "Found {} GPUs, but {} required",
                adapters.len(),
                config.min_devices
            )));
        }

        let num_devices = adapters.len().min(config.max_devices);
        info!(
            "Found {} compatible GPUs, using {}",
            adapters.len(),
            num_devices
        );

        let mut devices = Vec::new();
        let mut device_info = Vec::new();

        // Initialize each adapter; individual failures are logged and
        // skipped, the minimum-device check below catches total shortfall.
        for (index, adapter) in adapters.into_iter().take(num_devices).enumerate() {
            match Self::create_device_context(adapter, index).await {
                Ok((context, info)) => {
                    devices.push(Arc::new(context));
                    device_info.push(info);
                    info!(
                        "Initialized: {}",
                        device_info
                            .last()
                            .map(|i| i.description())
                            .unwrap_or_default()
                    );
                }
                Err(e) => {
                    warn!("Failed to initialize GPU {}: {}", index, e);
                }
            }
        }

        if devices.len() < config.min_devices {
            return Err(GpuError::device_request(format!(
                "Successfully initialized {} GPUs, but {} required",
                devices.len(),
                config.min_devices
            )));
        }

        let load_state = Arc::new(Mutex::new(LoadBalanceState::new(devices.len())));

        Ok(Self {
            devices,
            device_info,
            config,
            load_state,
        })
    }
197
    /// Number of successfully initialized GPU contexts.
    pub fn num_devices(&self) -> usize {
        self.devices.len()
    }

    /// Context for device `index`, or `None` when out of range.
    pub fn device(&self, index: usize) -> Option<&Arc<GpuContext>> {
        self.devices.get(index)
    }

    /// All initialized device contexts, indexed by device number.
    pub fn devices(&self) -> &[Arc<GpuContext>] {
        &self.devices
    }

    /// Metadata for device `index`, or `None` when out of range.
    pub fn device_info(&self, index: usize) -> Option<&GpuDeviceInfo> {
        self.device_info.get(index)
    }

    /// Metadata for all initialized devices, indexed by device number.
    pub fn all_device_info(&self) -> &[GpuDeviceInfo] {
        &self.device_info
    }
222
223 pub fn select_device(&self) -> usize {
225 if !self.config.auto_load_balance {
226 return 0; }
229
230 self.load_state
231 .lock()
232 .map(|state| state.select_device())
233 .unwrap_or(0)
234 }
235
236 pub fn dispatch<F, T>(&self, workload: f64, f: F) -> GpuResult<T>
238 where
239 F: FnOnce(&GpuContext) -> GpuResult<T>,
240 {
241 let device_idx = self.select_device();
242
243 if let Ok(mut state) = self.load_state.lock() {
244 state.add_task(device_idx, workload);
245 }
246
247 let context = self
248 .devices
249 .get(device_idx)
250 .ok_or_else(|| GpuError::internal("Invalid device index"))?;
251
252 let result = f(context);
253
254 if let Ok(mut state) = self.load_state.lock() {
255 state.complete_task(device_idx, workload);
256 }
257
258 result
259 }
260
261 pub async fn distribute<F, T>(&self, items: Vec<(f64, F)>) -> Vec<GpuResult<T>>
263 where
264 F: FnOnce(&GpuContext) -> GpuResult<T> + Send + 'static,
265 T: Send + 'static,
266 {
267 let mut tasks = Vec::new();
268
269 for (workload, work_fn) in items {
270 let device_idx = self.select_device();
271
272 if let Ok(mut state) = self.load_state.lock() {
273 state.add_task(device_idx, workload);
274 }
275
276 let context = match self.devices.get(device_idx) {
277 Some(ctx) => Arc::clone(ctx),
278 None => continue,
279 };
280
281 let load_state = Arc::clone(&self.load_state);
282
283 let task = tokio::spawn(async move {
284 let result = work_fn(&context);
285
286 if let Ok(mut state) = load_state.lock() {
287 state.complete_task(device_idx, workload);
288 }
289
290 result
291 });
292
293 tasks.push(task);
294 }
295
296 let mut results = Vec::new();
298 for task in tasks {
299 match task.await {
300 Ok(result) => results.push(result),
301 Err(e) => results.push(Err(GpuError::internal(e.to_string()))),
302 }
303 }
304
305 results
306 }
307
308 pub fn load_stats(&self) -> HashMap<usize, (usize, f64)> {
310 self.load_state
311 .lock()
312 .map(|state| {
313 let mut stats = HashMap::new();
314 for i in 0..self.devices.len() {
315 let tasks = *state.task_counts.get(&i).unwrap_or(&0);
316 let workload = *state.workload.get(&i).unwrap_or(&0.0);
317 stats.insert(i, (tasks, workload));
318 }
319 stats
320 })
321 .unwrap_or_default()
322 }
323
    /// Probes each supported backend for one high-performance adapter.
    ///
    /// NOTE(review): the `_instance` parameter is unused — a fresh `Instance`
    /// is created per backend instead, so the caller's `Backends` filter is
    /// bypassed here. Also, `request_adapter` yields at most ONE adapter per
    /// backend, so multiple GPUs behind the same backend are not all
    /// discovered, and `Backends::GL` is never probed. Confirm whether these
    /// limitations are intended.
    async fn enumerate_adapters(_instance: &Instance) -> Vec<Adapter> {
        let mut adapters = Vec::new();

        // One probe per backend family.
        for backend in &[
            Backends::VULKAN,
            Backends::METAL,
            Backends::DX12,
            Backends::BROWSER_WEBGPU,
        ] {
            let backend_instance = Instance::new(&wgpu::InstanceDescriptor {
                backends: *backend,
                ..Default::default()
            });

            if let Ok(adapter) = backend_instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    force_fallback_adapter: false,
                    compatible_surface: None,
                })
                .await
            {
                adapters.push(adapter);
            }
        }

        adapters
    }
353
    /// Builds a `GpuContext` plus descriptive metadata for one adapter.
    ///
    /// NOTE(review): `adapter` is consumed only for its metadata;
    /// `GpuContext::with_config` presumably selects its own adapter
    /// internally, so the returned context may not correspond to this
    /// adapter — verify against `GpuContext`'s implementation.
    async fn create_device_context(
        adapter: Adapter,
        index: usize,
    ) -> GpuResult<(GpuContext, GpuDeviceInfo)> {
        let adapter_info = adapter.get_info();
        let backend = adapter_info.backend;

        // Heuristic only — wgpu does not report real VRAM capacity.
        let vram_bytes = Self::estimate_vram(&adapter_info);

        let config = GpuContextConfig::default().with_label(format!("GPU {}", index));

        let context = GpuContext::with_config(config).await?;

        let info = GpuDeviceInfo {
            index,
            adapter_info,
            backend,
            vram_bytes,
            active: true,
        };

        Ok((context, info))
    }
378
379 fn estimate_vram(adapter_info: &AdapterInfo) -> Option<u64> {
380 match adapter_info.device_type {
382 wgpu::DeviceType::DiscreteGpu => Some(8 * 1024 * 1024 * 1024), wgpu::DeviceType::IntegratedGpu => Some(2 * 1024 * 1024 * 1024), wgpu::DeviceType::VirtualGpu => Some(4 * 1024 * 1024 * 1024), _ => None,
386 }
387 }
388}
389
/// Helper for moving data between GPUs managed by a `MultiGpuManager`.
pub struct InterGpuTransfer {
    /// Manager that owns the device contexts transfers operate on.
    manager: Arc<MultiGpuManager>,
}

impl InterGpuTransfer {
    /// Wraps `manager` for inter-GPU transfer operations.
    pub fn new(manager: Arc<MultiGpuManager>) -> Self {
        Self { manager }
    }

    /// Stages `data` into a fresh buffer on `dst_device` via a queue write.
    ///
    /// NOTE(review): the created buffer is dropped when this function
    /// returns and the source context is never read, so the call currently
    /// has no lasting effect — this looks like a placeholder. A real
    /// transfer would read the source buffer and hand the destination
    /// buffer back to the caller.
    ///
    /// # Errors
    /// `GpuError::invalid_buffer` when either device index is out of range.
    pub async fn copy_buffer(
        &self,
        src_device: usize,
        dst_device: usize,
        data: &[u8],
    ) -> GpuResult<()> {
        // Source context is looked up only to validate the index.
        let _src_ctx = self
            .manager
            .device(src_device)
            .ok_or_else(|| GpuError::invalid_buffer("Invalid source device"))?;

        let dst_ctx = self
            .manager
            .device(dst_device)
            .ok_or_else(|| GpuError::invalid_buffer("Invalid destination device"))?;

        let dst_buffer = dst_ctx.device().create_buffer(&wgpu::BufferDescriptor {
            label: Some("Inter-GPU Transfer"),
            size: data.len() as u64,
            usage: BufferUsages::COPY_DST | BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        dst_ctx.queue().write_buffer(&dst_buffer, 0, data);

        debug!(
            "Transferred {} bytes from GPU {} to GPU {}",
            data.len(),
            src_device,
            dst_device
        );

        Ok(())
    }

    /// Copies `data` from device 0 to every other device, sequentially.
    ///
    /// # Errors
    /// Propagates the first failing `copy_buffer` call.
    pub async fn broadcast(&self, data: &[u8]) -> GpuResult<()> {
        for i in 1..self.manager.num_devices() {
            self.copy_buffer(0, i, data).await?;
        }

        Ok(())
    }

    /// Gathers data from every device except `dst_device`.
    ///
    /// NOTE(review): stub — returns one EMPTY `Vec<u8>` per source device
    /// and performs no GPU reads yet.
    pub async fn gather(&self, dst_device: usize) -> GpuResult<Vec<Vec<u8>>> {
        let mut results = Vec::new();

        for i in 0..self.manager.num_devices() {
            if i == dst_device {
                continue;
            }

            results.push(Vec::new());
        }

        Ok(results)
    }
}
477
/// Groups devices that belong together so work can be placed near its data.
pub struct GpuAffinityManager {
    /// Group id -> device indices belonging to that group.
    affinity_groups: HashMap<usize, Vec<usize>>,
}

impl GpuAffinityManager {
    /// Creates an empty manager with no affinity groups.
    pub fn new() -> Self {
        Self {
            affinity_groups: HashMap::new(),
        }
    }

    /// Registers (or replaces) the device list for `group_id`.
    pub fn set_affinity_group(&mut self, group_id: usize, devices: Vec<usize>) {
        self.affinity_groups.insert(group_id, devices);
    }

    /// Returns the group containing `device`, or a singleton `[device]`
    /// when it belongs to no registered group.
    pub fn get_affinity_group(&self, device: usize) -> Vec<usize> {
        self.affinity_groups
            .values()
            .find(|devices| devices.contains(&device))
            .cloned()
            .unwrap_or_else(|| vec![device])
    }

    /// True when both devices share an affinity group; a device is always
    /// affine with itself.
    pub fn same_affinity(&self, device_a: usize, device_b: usize) -> bool {
        self.get_affinity_group(device_a).contains(&device_b)
    }

    /// Picks from `available` the first device affine with `data_device`,
    /// falling back to the first available device, or `None` when empty.
    pub fn optimal_device(&self, data_device: usize, available: &[usize]) -> Option<usize> {
        let group = self.get_affinity_group(data_device);

        available
            .iter()
            .copied()
            .find(|device| group.contains(device))
            .or_else(|| available.first().copied())
    }
}

impl Default for GpuAffinityManager {
    fn default() -> Self {
        Self::new()
    }
}
534
/// How `WorkDistributor` partitions items across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributionStrategy {
    /// Items assigned cyclically: item `i` goes to device `i % n`.
    RoundRobin,
    /// Items split in inverse proportion to each device's current workload.
    LoadBalanced,
    /// Placement near the data's device (currently delegates to round-robin).
    DataLocal,
    /// Everything goes to device 0.
    SingleDevice,
}
547
/// Partitions batches of work items across a `MultiGpuManager`'s devices.
pub struct WorkDistributor {
    /// Manager providing device count and live load statistics.
    manager: Arc<MultiGpuManager>,
    /// Partitioning policy applied by `distribute_work`.
    strategy: DistributionStrategy,
    /// Affinity groups (not consulted by any strategy yet — see `data_local`).
    affinity: GpuAffinityManager,
}
554
555impl WorkDistributor {
    /// Creates a distributor over `manager` using `strategy`, with an empty
    /// affinity configuration.
    pub fn new(manager: Arc<MultiGpuManager>, strategy: DistributionStrategy) -> Self {
        Self {
            manager,
            strategy,
            affinity: GpuAffinityManager::new(),
        }
    }

    /// Registers (or replaces) an affinity group on the internal manager.
    pub fn set_affinity_group(&mut self, group_id: usize, devices: Vec<usize>) {
        self.affinity.set_affinity_group(group_id, devices);
    }

    /// Partitions `items` into `(device_index, items)` batches according to
    /// the configured strategy.
    pub fn distribute_work<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
        match self.strategy {
            DistributionStrategy::RoundRobin => self.round_robin(items),
            DistributionStrategy::LoadBalanced => self.load_balanced(items),
            DistributionStrategy::DataLocal => self.data_local(items),
            DistributionStrategy::SingleDevice => self.single_device(items),
        }
    }
579
580 fn round_robin<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
581 let num_devices = self.manager.num_devices();
582 let mut device_items: Vec<Vec<T>> = (0..num_devices).map(|_| Vec::new()).collect();
583
584 for (idx, item) in items.into_iter().enumerate() {
585 device_items[idx % num_devices].push(item);
586 }
587
588 device_items
589 .into_iter()
590 .enumerate()
591 .filter(|(_, items)| !items.is_empty())
592 .collect()
593 }
594
595 fn load_balanced<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
596 let stats = self.manager.load_stats();
597 let num_devices = self.manager.num_devices();
598 let items_len = items.len();
599
600 let mut weights: Vec<f64> = (0..num_devices)
602 .map(|i| {
603 let (_, load) = stats.get(&i).unwrap_or(&(0, 0.0));
604 1.0 / (1.0 + load)
605 })
606 .collect();
607
608 let total: f64 = weights.iter().sum();
610 if total > 0.0 {
611 for w in &mut weights {
612 *w /= total;
613 }
614 }
615
616 let mut device_items: Vec<Vec<T>> = (0..num_devices).map(|_| Vec::new()).collect();
618
619 for (idx, item) in items.into_iter().enumerate() {
620 let target = (idx as f64 + 0.5) / items_len as f64;
621 let mut device = 0;
622 let mut cumulative = 0.0;
623
624 for (dev, weight) in weights.iter().enumerate() {
625 cumulative += weight;
626 if cumulative >= target {
627 device = dev;
628 break;
629 }
630 }
631
632 device_items[device].push(item);
633 }
634
635 device_items
636 .into_iter()
637 .enumerate()
638 .filter(|(_, items)| !items.is_empty())
639 .collect()
640 }
641
    /// Data-local placement.
    ///
    /// NOTE(review): currently just delegates to round-robin; the affinity
    /// manager is never consulted here yet.
    fn data_local<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
        self.round_robin(items)
    }

    /// Sends every item to device 0 as a single batch.
    fn single_device<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
        vec![(0, items)]
    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config sanity: one device minimum, eight max, balancing on.
    #[test]
    fn test_multi_gpu_config() {
        let config = MultiGpuConfig::default();
        assert_eq!(config.min_devices, 1);
        assert_eq!(config.max_devices, 8);
        assert!(config.auto_load_balance);
    }

    /// Least-loaded selection and completion accounting.
    #[test]
    fn test_load_balance_state() {
        let mut state = LoadBalanceState::new(3);

        state.add_task(0, 100.0);
        state.add_task(1, 50.0);
        state.add_task(2, 75.0);

        // Device 1 carries the least load (50.0).
        assert_eq!(state.select_device(), 1);

        // After completing its task, device 1 is idle and still wins.
        state.complete_task(1, 50.0);
        assert_eq!(state.select_device(), 1);
    }

    /// Group membership, cross-group separation, and group lookup.
    #[test]
    fn test_affinity_manager() {
        let mut manager = GpuAffinityManager::new();

        manager.set_affinity_group(0, vec![0, 1]);
        manager.set_affinity_group(1, vec![2, 3]);

        assert!(manager.same_affinity(0, 1));
        assert!(manager.same_affinity(2, 3));
        assert!(!manager.same_affinity(0, 2));

        let group = manager.get_affinity_group(0);
        assert_eq!(group, vec![0, 1]);
    }

    /// The strategy enum derives `PartialEq`/`Eq` as expected.
    #[test]
    fn test_distribution_strategy() {
        assert_eq!(
            DistributionStrategy::RoundRobin,
            DistributionStrategy::RoundRobin
        );
    }
}