Skip to main content

oxigdal_gpu/
multi_gpu.rs

1//! Multi-GPU support for distributed GPU computing.
2//!
3//! This module provides infrastructure for managing multiple GPUs,
4//! distributing work across devices, and handling inter-GPU data transfers.
5
6use crate::context::{GpuContext, GpuContextConfig};
7use crate::error::{GpuError, GpuResult};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use tracing::{debug, info, warn};
11use wgpu::{Adapter, AdapterInfo, Backend, Backends, BufferUsages, Instance};
12
13/// Multi-GPU configuration.
14#[derive(Debug, Clone)]
15pub struct MultiGpuConfig {
16    /// Backends to search for GPUs.
17    pub backends: Backends,
18    /// Minimum number of GPUs required.
19    pub min_devices: usize,
20    /// Maximum number of GPUs to use.
21    pub max_devices: usize,
22    /// Enable automatic load balancing.
23    pub auto_load_balance: bool,
24    /// Enable peer-to-peer transfers (if supported).
25    pub enable_p2p: bool,
26}
27
28impl Default for MultiGpuConfig {
29    fn default() -> Self {
30        Self {
31            backends: Backends::all(),
32            min_devices: 1,
33            max_devices: 8,
34            auto_load_balance: true,
35            enable_p2p: false,
36        }
37    }
38}
39
40/// Information about a GPU device.
41#[derive(Debug, Clone)]
42pub struct GpuDeviceInfo {
43    /// Device index.
44    pub index: usize,
45    /// Adapter information.
46    pub adapter_info: AdapterInfo,
47    /// Backend type.
48    pub backend: Backend,
49    /// Estimated VRAM in bytes (if available).
50    pub vram_bytes: Option<u64>,
51    /// Device is currently active.
52    pub active: bool,
53}
54
55impl GpuDeviceInfo {
56    /// Get a human-readable description.
57    pub fn description(&self) -> String {
58        format!(
59            "GPU {} : {} ({:?})",
60            self.index, self.adapter_info.name, self.backend
61        )
62    }
63}
64
65/// Multi-GPU manager for coordinating multiple devices.
66pub struct MultiGpuManager {
67    /// Available GPU contexts.
68    devices: Vec<Arc<GpuContext>>,
69    /// Device information.
70    device_info: Vec<GpuDeviceInfo>,
71    /// Configuration.
72    config: MultiGpuConfig,
73    /// Load balancing state.
74    load_state: Arc<Mutex<LoadBalanceState>>,
75}
76
/// Per-device bookkeeping used for load-balanced device selection.
#[derive(Debug, Clone)]
struct LoadBalanceState {
    /// Number of tasks dispatched to each device and not yet completed.
    task_counts: HashMap<usize, usize>,
    /// Estimated outstanding workload on each device (arbitrary caller units).
    workload: HashMap<usize, f64>,
}

impl LoadBalanceState {
    /// Create state for `num_devices` devices, all idle.
    fn new(num_devices: usize) -> Self {
        Self {
            task_counts: (0..num_devices).map(|i| (i, 0)).collect(),
            workload: (0..num_devices).map(|i| (i, 0.0)).collect(),
        }
    }

    /// Clamp a caller-supplied workload estimate to a finite, non-negative value.
    ///
    /// A NaN or infinite estimate would otherwise stick in the running total
    /// forever (`NaN - x` and `inf - inf` are NaN), and a negative one would
    /// let a device's recorded load drift below its true value.
    fn sanitize(workload: f64) -> f64 {
        if workload.is_finite() {
            workload.max(0.0)
        } else {
            0.0
        }
    }

    /// Number of in-flight tasks recorded for `device` (0 if untracked).
    fn tasks_on(&self, device: usize) -> usize {
        self.task_counts.get(&device).copied().unwrap_or(0)
    }

    /// Return the device with the smallest estimated workload.
    ///
    /// Ties are broken by fewer in-flight tasks, then by lowest device index,
    /// so the result does not depend on `HashMap` iteration order.
    /// Returns 0 when no devices are tracked.
    fn select_device(&self) -> usize {
        self.workload
            .iter()
            .min_by(|&(&a_idx, &a_load), &(&b_idx, &b_load)| {
                a_load
                    .total_cmp(&b_load)
                    .then_with(|| self.tasks_on(a_idx).cmp(&self.tasks_on(b_idx)))
                    .then_with(|| a_idx.cmp(&b_idx))
            })
            .map(|(idx, _)| *idx)
            .unwrap_or(0)
    }

    /// Record a new task with estimated `workload` on `device`.
    fn add_task(&mut self, device: usize, workload: f64) {
        *self.task_counts.entry(device).or_insert(0) += 1;
        *self.workload.entry(device).or_insert(0.0) += Self::sanitize(workload);
    }

    /// Record completion of a task with estimated `workload` on `device`.
    ///
    /// Unknown devices are ignored; counts and loads saturate at zero.
    fn complete_task(&mut self, device: usize, workload: f64) {
        if let Some(count) = self.task_counts.get_mut(&device) {
            *count = count.saturating_sub(1);
        }
        if let Some(load) = self.workload.get_mut(&device) {
            // Saturate so mismatched estimates or float rounding never go negative.
            *load = (*load - Self::sanitize(workload)).max(0.0);
        }
    }
}
124
125impl MultiGpuManager {
126    /// Create a new multi-GPU manager.
127    ///
128    /// # Errors
129    ///
130    /// Returns an error if minimum number of devices cannot be found.
131    pub async fn new(config: MultiGpuConfig) -> GpuResult<Self> {
132        info!("Initializing multi-GPU manager");
133
134        let instance = Instance::new(&wgpu::InstanceDescriptor {
135            backends: config.backends,
136            ..Default::default()
137        });
138
139        // Enumerate all available adapters
140        let adapters = Self::enumerate_adapters(&instance).await;
141
142        if adapters.len() < config.min_devices {
143            return Err(GpuError::no_adapter(format!(
144                "Found {} GPUs, but {} required",
145                adapters.len(),
146                config.min_devices
147            )));
148        }
149
150        let num_devices = adapters.len().min(config.max_devices);
151        info!(
152            "Found {} compatible GPUs, using {}",
153            adapters.len(),
154            num_devices
155        );
156
157        // Create contexts for each device
158        let mut devices = Vec::new();
159        let mut device_info = Vec::new();
160
161        for (index, adapter) in adapters.into_iter().take(num_devices).enumerate() {
162            match Self::create_device_context(adapter, index).await {
163                Ok((context, info)) => {
164                    devices.push(Arc::new(context));
165                    device_info.push(info);
166                    info!(
167                        "Initialized: {}",
168                        device_info
169                            .last()
170                            .map(|i| i.description())
171                            .unwrap_or_default()
172                    );
173                }
174                Err(e) => {
175                    warn!("Failed to initialize GPU {}: {}", index, e);
176                }
177            }
178        }
179
180        if devices.len() < config.min_devices {
181            return Err(GpuError::device_request(format!(
182                "Successfully initialized {} GPUs, but {} required",
183                devices.len(),
184                config.min_devices
185            )));
186        }
187
188        let load_state = Arc::new(Mutex::new(LoadBalanceState::new(devices.len())));
189
190        Ok(Self {
191            devices,
192            device_info,
193            config,
194            load_state,
195        })
196    }
197
198    /// Get the number of available devices.
199    pub fn num_devices(&self) -> usize {
200        self.devices.len()
201    }
202
203    /// Get a specific device context.
204    pub fn device(&self, index: usize) -> Option<&Arc<GpuContext>> {
205        self.devices.get(index)
206    }
207
208    /// Get all device contexts.
209    pub fn devices(&self) -> &[Arc<GpuContext>] {
210        &self.devices
211    }
212
213    /// Get device information.
214    pub fn device_info(&self, index: usize) -> Option<&GpuDeviceInfo> {
215        self.device_info.get(index)
216    }
217
218    /// Get all device information.
219    pub fn all_device_info(&self) -> &[GpuDeviceInfo] {
220        &self.device_info
221    }
222
223    /// Select a device based on load balancing strategy.
224    pub fn select_device(&self) -> usize {
225        if !self.config.auto_load_balance {
226            // Round-robin without load balancing (use simple counter)
227            return 0; // Simplified for now
228        }
229
230        self.load_state
231            .lock()
232            .map(|state| state.select_device())
233            .unwrap_or(0)
234    }
235
236    /// Dispatch work to a device with load balancing.
237    pub fn dispatch<F, T>(&self, workload: f64, f: F) -> GpuResult<T>
238    where
239        F: FnOnce(&GpuContext) -> GpuResult<T>,
240    {
241        let device_idx = self.select_device();
242
243        if let Ok(mut state) = self.load_state.lock() {
244            state.add_task(device_idx, workload);
245        }
246
247        let context = self
248            .devices
249            .get(device_idx)
250            .ok_or_else(|| GpuError::internal("Invalid device index"))?;
251
252        let result = f(context);
253
254        if let Ok(mut state) = self.load_state.lock() {
255            state.complete_task(device_idx, workload);
256        }
257
258        result
259    }
260
261    /// Distribute work across all devices.
262    pub async fn distribute<F, T>(&self, items: Vec<(f64, F)>) -> Vec<GpuResult<T>>
263    where
264        F: FnOnce(&GpuContext) -> GpuResult<T> + Send + 'static,
265        T: Send + 'static,
266    {
267        let mut tasks = Vec::new();
268
269        for (workload, work_fn) in items {
270            let device_idx = self.select_device();
271
272            if let Ok(mut state) = self.load_state.lock() {
273                state.add_task(device_idx, workload);
274            }
275
276            let context = match self.devices.get(device_idx) {
277                Some(ctx) => Arc::clone(ctx),
278                None => continue,
279            };
280
281            let load_state = Arc::clone(&self.load_state);
282
283            let task = tokio::spawn(async move {
284                let result = work_fn(&context);
285
286                if let Ok(mut state) = load_state.lock() {
287                    state.complete_task(device_idx, workload);
288                }
289
290                result
291            });
292
293            tasks.push(task);
294        }
295
296        // Wait for all tasks to complete
297        let mut results = Vec::new();
298        for task in tasks {
299            match task.await {
300                Ok(result) => results.push(result),
301                Err(e) => results.push(Err(GpuError::internal(e.to_string()))),
302            }
303        }
304
305        results
306    }
307
308    /// Get current load statistics.
309    pub fn load_stats(&self) -> HashMap<usize, (usize, f64)> {
310        self.load_state
311            .lock()
312            .map(|state| {
313                let mut stats = HashMap::new();
314                for i in 0..self.devices.len() {
315                    let tasks = *state.task_counts.get(&i).unwrap_or(&0);
316                    let workload = *state.workload.get(&i).unwrap_or(&0.0);
317                    stats.insert(i, (tasks, workload));
318                }
319                stats
320            })
321            .unwrap_or_default()
322    }
323
324    async fn enumerate_adapters(_instance: &Instance) -> Vec<Adapter> {
325        let mut adapters = Vec::new();
326
327        // Try each backend
328        for backend in &[
329            Backends::VULKAN,
330            Backends::METAL,
331            Backends::DX12,
332            Backends::BROWSER_WEBGPU,
333        ] {
334            let backend_instance = Instance::new(&wgpu::InstanceDescriptor {
335                backends: *backend,
336                ..Default::default()
337            });
338
339            if let Ok(adapter) = backend_instance
340                .request_adapter(&wgpu::RequestAdapterOptions {
341                    power_preference: wgpu::PowerPreference::HighPerformance,
342                    force_fallback_adapter: false,
343                    compatible_surface: None,
344                })
345                .await
346            {
347                adapters.push(adapter);
348            }
349        }
350
351        adapters
352    }
353
354    async fn create_device_context(
355        adapter: Adapter,
356        index: usize,
357    ) -> GpuResult<(GpuContext, GpuDeviceInfo)> {
358        let adapter_info = adapter.get_info();
359        let backend = adapter_info.backend;
360
361        // Estimate VRAM (not directly available in WGPU)
362        let vram_bytes = Self::estimate_vram(&adapter_info);
363
364        let config = GpuContextConfig::default().with_label(format!("GPU {}", index));
365
366        let context = GpuContext::with_config(config).await?;
367
368        let info = GpuDeviceInfo {
369            index,
370            adapter_info,
371            backend,
372            vram_bytes,
373            active: true,
374        };
375
376        Ok((context, info))
377    }
378
379    fn estimate_vram(adapter_info: &AdapterInfo) -> Option<u64> {
380        // This is a rough estimation based on device type
381        match adapter_info.device_type {
382            wgpu::DeviceType::DiscreteGpu => Some(8 * 1024 * 1024 * 1024), // 8 GB
383            wgpu::DeviceType::IntegratedGpu => Some(2 * 1024 * 1024 * 1024), // 2 GB
384            wgpu::DeviceType::VirtualGpu => Some(4 * 1024 * 1024 * 1024),  // 4 GB
385            _ => None,
386        }
387    }
388}
389
390/// Inter-GPU data transfer manager.
391pub struct InterGpuTransfer {
392    manager: Arc<MultiGpuManager>,
393}
394
395impl InterGpuTransfer {
396    /// Create a new inter-GPU transfer manager.
397    pub fn new(manager: Arc<MultiGpuManager>) -> Self {
398        Self { manager }
399    }
400
401    /// Copy data between GPUs.
402    ///
403    /// # Errors
404    ///
405    /// Returns an error if transfer fails or devices are invalid.
406    pub async fn copy_buffer(
407        &self,
408        src_device: usize,
409        dst_device: usize,
410        data: &[u8],
411    ) -> GpuResult<()> {
412        let _src_ctx = self
413            .manager
414            .device(src_device)
415            .ok_or_else(|| GpuError::invalid_buffer("Invalid source device"))?;
416
417        let dst_ctx = self
418            .manager
419            .device(dst_device)
420            .ok_or_else(|| GpuError::invalid_buffer("Invalid destination device"))?;
421
422        // Create buffer on destination device
423        let dst_buffer = dst_ctx.device().create_buffer(&wgpu::BufferDescriptor {
424            label: Some("Inter-GPU Transfer"),
425            size: data.len() as u64,
426            usage: BufferUsages::COPY_DST | BufferUsages::STORAGE,
427            mapped_at_creation: false,
428        });
429
430        // Write data to destination
431        dst_ctx.queue().write_buffer(&dst_buffer, 0, data);
432
433        debug!(
434            "Transferred {} bytes from GPU {} to GPU {}",
435            data.len(),
436            src_device,
437            dst_device
438        );
439
440        Ok(())
441    }
442
443    /// Broadcast data to all GPUs.
444    ///
445    /// # Errors
446    ///
447    /// Returns an error if any transfer fails.
448    pub async fn broadcast(&self, data: &[u8]) -> GpuResult<()> {
449        for i in 1..self.manager.num_devices() {
450            self.copy_buffer(0, i, data).await?;
451        }
452
453        Ok(())
454    }
455
456    /// Gather data from all GPUs to one device.
457    ///
458    /// # Errors
459    ///
460    /// Returns an error if any transfer fails.
461    pub async fn gather(&self, dst_device: usize) -> GpuResult<Vec<Vec<u8>>> {
462        let mut results = Vec::new();
463
464        for i in 0..self.manager.num_devices() {
465            if i == dst_device {
466                continue;
467            }
468
469            // In a real implementation, we would read from the source GPU
470            // For now, this is a placeholder
471            results.push(Vec::new());
472        }
473
474        Ok(results)
475    }
476}
477
/// GPU affinity manager for NUMA-aware scheduling.
///
/// Groups are registered explicitly via `set_affinity_group`; this type does
/// not detect topology itself. A device in no group forms a group of its own.
pub struct GpuAffinityManager {
    /// Device affinity groups (devices that share memory/PCIe bus), keyed by group id.
    affinity_groups: HashMap<usize, Vec<usize>>,
}

impl GpuAffinityManager {
    /// Create a new affinity manager with no groups.
    pub fn new() -> Self {
        Self {
            affinity_groups: HashMap::new(),
        }
    }

    /// Set the devices of group `group_id`, replacing any previous members.
    pub fn set_affinity_group(&mut self, group_id: usize, devices: Vec<usize>) {
        self.affinity_groups.insert(group_id, devices);
    }

    /// Find the group containing `device`.
    ///
    /// If several groups contain it, the lowest group id wins, so the result
    /// does not depend on `HashMap` iteration order.
    fn group_of(&self, device: usize) -> Option<&[usize]> {
        self.affinity_groups
            .iter()
            .filter(|(_, members)| members.contains(&device))
            .min_by_key(|&(&group_id, _)| group_id)
            .map(|(_, members)| members.as_slice())
    }

    /// Get the devices in the same affinity group as `device`.
    ///
    /// Returns `vec![device]` if `device` belongs to no registered group.
    pub fn get_affinity_group(&self, device: usize) -> Vec<usize> {
        self.group_of(device)
            .map(<[usize]>::to_vec)
            .unwrap_or_else(|| vec![device])
    }

    /// Check if two devices are in the same affinity group.
    pub fn same_affinity(&self, device_a: usize, device_b: usize) -> bool {
        match self.group_of(device_a) {
            Some(members) => members.contains(&device_b),
            None => device_a == device_b,
        }
    }

    /// Get the optimal device for data that lives on `data_device`.
    ///
    /// Returns the first entry of `available` in `data_device`'s affinity
    /// group, otherwise the first entry of `available`, or `None` if it is empty.
    pub fn optimal_device(&self, data_device: usize, available: &[usize]) -> Option<usize> {
        let group = self.group_of(data_device);
        let in_group = |candidate: usize| match group {
            Some(members) => members.contains(&candidate),
            None => candidate == data_device,
        };

        available
            .iter()
            .copied()
            .find(|&candidate| in_group(candidate))
            .or_else(|| available.first().copied())
    }
}

impl Default for GpuAffinityManager {
    fn default() -> Self {
        Self::new()
    }
}
534
/// Work distribution strategy for multi-GPU processing.
///
/// Interpreted by `WorkDistributor::distribute_work`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistributionStrategy {
    /// Deal items out to devices in turn (item `i` goes to device `i % n`).
    RoundRobin,
    /// Split items into contiguous batches sized by each device's current
    /// load, so less-loaded devices receive more items.
    LoadBalanced,
    /// Intended to place work near its data; currently behaves like `RoundRobin`.
    DataLocal,
    /// Send every item to a single device (currently always device 0).
    SingleDevice,
}
547
548/// Work distributor for multi-GPU task scheduling.
549pub struct WorkDistributor {
550    manager: Arc<MultiGpuManager>,
551    strategy: DistributionStrategy,
552    affinity: GpuAffinityManager,
553}
554
555impl WorkDistributor {
556    /// Create a new work distributor.
557    pub fn new(manager: Arc<MultiGpuManager>, strategy: DistributionStrategy) -> Self {
558        Self {
559            manager,
560            strategy,
561            affinity: GpuAffinityManager::new(),
562        }
563    }
564
565    /// Set affinity group.
566    pub fn set_affinity_group(&mut self, group_id: usize, devices: Vec<usize>) {
567        self.affinity.set_affinity_group(group_id, devices);
568    }
569
570    /// Distribute work items across GPUs.
571    pub fn distribute_work<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
572        match self.strategy {
573            DistributionStrategy::RoundRobin => self.round_robin(items),
574            DistributionStrategy::LoadBalanced => self.load_balanced(items),
575            DistributionStrategy::DataLocal => self.data_local(items),
576            DistributionStrategy::SingleDevice => self.single_device(items),
577        }
578    }
579
580    fn round_robin<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
581        let num_devices = self.manager.num_devices();
582        let mut device_items: Vec<Vec<T>> = (0..num_devices).map(|_| Vec::new()).collect();
583
584        for (idx, item) in items.into_iter().enumerate() {
585            device_items[idx % num_devices].push(item);
586        }
587
588        device_items
589            .into_iter()
590            .enumerate()
591            .filter(|(_, items)| !items.is_empty())
592            .collect()
593    }
594
595    fn load_balanced<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
596        let stats = self.manager.load_stats();
597        let num_devices = self.manager.num_devices();
598        let items_len = items.len();
599
600        // Calculate weights based on inverse of current load
601        let mut weights: Vec<f64> = (0..num_devices)
602            .map(|i| {
603                let (_, load) = stats.get(&i).unwrap_or(&(0, 0.0));
604                1.0 / (1.0 + load)
605            })
606            .collect();
607
608        // Normalize weights
609        let total: f64 = weights.iter().sum();
610        if total > 0.0 {
611            for w in &mut weights {
612                *w /= total;
613            }
614        }
615
616        // Distribute items based on weights
617        let mut device_items: Vec<Vec<T>> = (0..num_devices).map(|_| Vec::new()).collect();
618
619        for (idx, item) in items.into_iter().enumerate() {
620            let target = (idx as f64 + 0.5) / items_len as f64;
621            let mut device = 0;
622            let mut cumulative = 0.0;
623
624            for (dev, weight) in weights.iter().enumerate() {
625                cumulative += weight;
626                if cumulative >= target {
627                    device = dev;
628                    break;
629                }
630            }
631
632            device_items[device].push(item);
633        }
634
635        device_items
636            .into_iter()
637            .enumerate()
638            .filter(|(_, items)| !items.is_empty())
639            .collect()
640    }
641
642    fn data_local<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
643        // For now, fall back to round-robin
644        // In a real implementation, this would consider data locality
645        self.round_robin(items)
646    }
647
648    fn single_device<T>(&self, items: Vec<T>) -> Vec<(usize, Vec<T>)> {
649        vec![(0, items)]
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_gpu_config() {
        let config = MultiGpuConfig::default();
        assert_eq!(config.backends, Backends::all());
        assert_eq!(config.min_devices, 1);
        assert_eq!(config.max_devices, 8);
        assert!(config.auto_load_balance);
        assert!(!config.enable_p2p);
    }

    #[test]
    fn test_load_balance_state() {
        let mut state = LoadBalanceState::new(3);

        state.add_task(0, 100.0);
        state.add_task(1, 50.0);
        state.add_task(2, 75.0);

        // Device 1 should have minimum load
        assert_eq!(state.select_device(), 1);

        state.complete_task(1, 50.0);
        assert_eq!(state.select_device(), 1);
    }

    #[test]
    fn test_complete_task_saturates_at_zero() {
        let mut state = LoadBalanceState::new(2);
        state.add_task(0, 10.0);
        state.complete_task(0, 25.0);
        state.complete_task(0, 1.0);
        assert_eq!(state.task_counts[&0], 0);
        assert_eq!(state.workload[&0], 0.0);
    }

    #[test]
    fn test_complete_task_unknown_device_is_noop() {
        let mut state = LoadBalanceState::new(1);
        state.complete_task(5, 1.0);
        assert!(!state.task_counts.contains_key(&5));
        assert!(!state.workload.contains_key(&5));
    }

    #[test]
    fn test_affinity_manager() {
        let mut manager = GpuAffinityManager::new();

        manager.set_affinity_group(0, vec![0, 1]);
        manager.set_affinity_group(1, vec![2, 3]);

        assert!(manager.same_affinity(0, 1));
        assert!(manager.same_affinity(2, 3));
        assert!(!manager.same_affinity(0, 2));

        let group = manager.get_affinity_group(0);
        assert_eq!(group, vec![0, 1]);
    }

    #[test]
    fn test_affinity_unknown_device_is_its_own_group() {
        let manager = GpuAffinityManager::default();
        assert_eq!(manager.get_affinity_group(4), vec![4]);
        assert!(manager.same_affinity(4, 4));
        assert!(!manager.same_affinity(4, 5));
    }

    #[test]
    fn test_optimal_device_prefers_affinity_group() {
        let mut manager = GpuAffinityManager::new();
        manager.set_affinity_group(0, vec![0, 1]);
        assert_eq!(manager.optimal_device(0, &[2, 1]), Some(1));
        assert_eq!(manager.optimal_device(0, &[2, 3]), Some(2));
        assert_eq!(manager.optimal_device(0, &[]), None);
    }

    #[test]
    fn test_distribution_strategy() {
        let strategy = DistributionStrategy::RoundRobin;
        let copy = strategy;
        assert_eq!(strategy, copy);
        assert_ne!(
            DistributionStrategy::RoundRobin,
            DistributionStrategy::SingleDevice
        );
    }
}